Soft Actor-Critic
Background
(Previously: Background for TD3)
Soft Actor Critic (SAC) is an algorithm that optimizes a stochastic policy in an off-policy way, forming a bridge between stochastic policy optimization and DDPG-style approaches. It isn’t a direct successor to TD3 (having been published roughly concurrently), but it incorporates the clipped double-Q trick, and due to the inherent stochasticity of the policy in SAC, it also winds up benefiting from something like target policy smoothing.
A central feature of SAC is entropy regularization. The policy is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy. This has a close connection to the exploration-exploitation trade-off: increasing entropy results in more exploration, which can accelerate learning later on. It can also prevent the policy from prematurely converging to a bad local optimum.
Quick Facts
- SAC is an off-policy algorithm.
- The version of SAC implemented here can only be used for environments with continuous action spaces.
- An alternate version of SAC, which slightly changes the policy update rule, can be implemented to handle discrete action spaces.
- The Spinning Up implementation of SAC does not support parallelization.
Key Equations
To explain Soft Actor Critic, we first have to introduce the entropy-regularized reinforcement learning setting. In entropy-regularized RL, there are slightly-different equations for value functions.
Entropy-Regularized Reinforcement Learning
Entropy is a quantity which, roughly speaking, says how random a random variable is. If a coin is weighted so that it almost always comes up heads, it has low entropy; if it’s evenly weighted and has a half chance of either outcome, it has high entropy.
Let $x$ be a random variable with probability mass or density function $P$. The entropy $H$ of $x$ is computed from its distribution $P$ according to
$$H(P) = \mathbb{E}_{x \sim P}\left[-\log P(x)\right].$$
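As a quick numerical illustration of the coin example, here is a minimal sketch that evaluates the entropy formula for two discrete distributions. The helper name `entropy` and the specific coin weights are assumptions chosen for illustration, not part of Spinning Up.

```python
import math

def entropy(probs):
    """Entropy H(P) = E_{x~P}[-log P(x)] of a discrete distribution, in nats."""
    # Outcomes with zero probability contribute nothing to the sum.
    return -sum(p * math.log(p) for p in probs if p > 0.0)

fair_coin = [0.5, 0.5]        # evenly weighted coin
weighted_coin = [0.99, 0.01]  # hypothetical coin that almost always comes up heads

h_fair = entropy(fair_coin)          # equals log(2), the maximum for two outcomes
h_weighted = entropy(weighted_coin)  # much closer to zero
```

The fair coin attains the largest entropy possible for two outcomes, while the heavily weighted coin has entropy near zero.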
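In entropy-regularized reinforcement learning, the agent gets a bonus reward at each time step proportional to the entropy of the policy at that timestep. This changes the RL problem to:
$$\pi^* = \arg\max_{\pi} \mathbb{E}_{\tau \sim \pi}\left[\sum_{t=0}^{\infty} \gamma^t \Big(R(s_t, a_t, s_{t+1}) + \alpha H\left(\pi(\cdot|s_t)\right)\Big)\right],$$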
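where $\alpha > 0$ is the trade-off coefficient. (Note: we’re assuming an infinite-horizon discounted setting here, and we’ll do the same for the rest of this page.) We can now define the slightly-different value functions in this setting. $V^{\pi}$ is changed to include the entropy bonuses from every timestep:
$$V^{\pi}(s) = \mathbb{E}_{\tau \sim \pi}\left[\left.\sum_{t=0}^{\infty} \gamma^t \Big(R(s_t, a_t, s_{t+1}) + \alpha H\left(\pi(\cdot|s_t)\right)\Big) \right| s_0 = s\right]$$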
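$Q^{\pi}$ is changed to include the entropy bonuses from every timestep except the first:
$$Q^{\pi}(s,a) = \mathbb{E}_{\tau \sim \pi}\left[\left.\sum_{t=0}^{\infty} \gamma^t R(s_t, a_t, s_{t+1}) + \alpha \sum_{t=1}^{\infty} \gamma^t H\left(\pi(\cdot|s_t)\right) \right| s_0 = s, a_0 = a\right]$$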
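With these definitions, $V^{\pi}$ and $Q^{\pi}$ are connected by:
$$V^{\pi}(s) = \mathbb{E}_{a \sim \pi}\left[Q^{\pi}(s,a)\right] + \alpha H\left(\pi(\cdot|s)\right)$$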
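and the Bellman equation for $Q^{\pi}$ is
$$\begin{aligned} Q^{\pi}(s,a) &= \mathbb{E}_{s' \sim P,\, a' \sim \pi}\left[R(s,a,s') + \gamma\Big(Q^{\pi}(s',a') + \alpha H\left(\pi(\cdot|s')\right)\Big)\right] \\ &= \mathbb{E}_{s' \sim P}\left[R(s,a,s') + \gamma V^{\pi}(s')\right]. \end{aligned}$$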
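To make the relationship between the two value functions concrete, the following sketch checks it on a toy problem: a single state that always transitions back to itself, with two actions. The rewards, policy probabilities, and coefficient values are hypothetical numbers chosen only for illustration.

```python
import math

def entropy(probs):
    """Entropy of a discrete distribution, in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

gamma, alpha = 0.9, 0.2
rewards = [1.0, 0.0]  # hypothetical reward for each of the two actions
policy = [0.7, 0.3]   # hypothetical fixed policy pi(a|s) in the single state

expected_reward = sum(p * r for p, r in zip(policy, rewards))
h = entropy(policy)

# V includes the entropy bonus at every timestep, so it is a geometric series
# whose per-step term is (expected reward + alpha * entropy).
v = (expected_reward + alpha * h) / (1.0 - gamma)

# Bellman equation: Q(s,a) = R(s,a) + gamma * V(s'), where s' = s in this toy problem.
q = [r + gamma * v for r in rewards]

# Connection between V and Q: V(s) = E_a[Q(s,a)] + alpha * H(pi(.|s)).
v_from_q = sum(p * q_a for p, q_a in zip(policy, q)) + alpha * h
```

Here `v` and `v_from_q` agree up to floating-point error, which is the identity connecting the two value functions.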
You Should Know
The way we’ve set up the value functions in the entropy-regularized setting is a little bit arbitrary, and actually we could have done it differently (eg make include the entropy bonus at the first timestep). The choice of definition may vary slightly across papers on the subject.
Soft Actor-Critic
SAC concurrently learns a policy $\pi_{\theta}$ and two Q-functions $Q_{\phi_1}, Q_{\phi_2}$. There are two variants of SAC that are currently standard: one that uses a fixed entropy regularization coefficient $\alpha$, and another that enforces an entropy constraint by varying $\alpha$ over the course of training. For simplicity, Spinning Up makes use of the version with a fixed entropy regularization coefficient, but the entropy-constrained variant is generally preferred by practitioners.
You Should Know
The SAC algorithm has changed a little bit over time. An older version of SAC also learns a value function in addition to the Q-functions; this page will focus on the modern version that omits the extra value function.
Learning Q. The Q-functions are learned in a similar way to TD3, but with a few key differences.
First, what’s similar?
- Like in TD3, both Q-functions are learned with MSBE minimization, by regressing to a single shared target.
- Like in TD3, the shared target is computed using target Q-networks, and the target Q-networks are obtained by polyak averaging the Q-network parameters over the course of training.
- Like in TD3, the shared target makes use of the clipped double-Q trick.
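The polyak averaging mentioned in the list above can be sketched as follows. This is a minimal illustration on plain Python lists standing in for network parameters; the function name `polyak_update` and the parameter values are assumptions, and a real implementation would apply the same rule to each parameter tensor.

```python
def polyak_update(target_params, main_params, polyak=0.995):
    """Return target parameters moved slightly towards the main parameters."""
    # target <- polyak * target + (1 - polyak) * main, applied elementwise.
    return [polyak * t + (1.0 - polyak) * m
            for t, m in zip(target_params, main_params)]

main = [1.0, -2.0, 0.5]   # stand-in for the Q-network parameters
target = [0.0, 0.0, 0.0]  # stand-in for the target Q-network parameters

# Repeated updates make the target track the main parameters slowly.
for _ in range(1000):
    target = polyak_update(target, main)
```

With `polyak` close to 1 the target changes only slightly per update, which keeps the regression target for the Q-functions comparatively stable.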
What’s different?
- Unlike in TD3, the target also includes a term that comes from SAC’s use of entropy regularization.
- Unlike in TD3, the next-state actions used in the target come from the current policy instead of a target policy.
- Unlike in TD3, there is no explicit target policy smoothing. TD3 trains a deterministic policy, and so it accomplishes smoothing by adding random noise to the next-state actions. SAC trains a stochastic policy, and so the noise from that stochasticity is sufficient to get a similar effect.
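Before we give the final form of the Q-loss, let’s take a moment to discuss how the contribution from entropy regularization comes in. We’ll start by taking our recursive Bellman equation for the entropy-regularized $Q^{\pi}$ from earlier, and rewriting it a little bit by using the definition of entropy:
$$\begin{aligned} Q^{\pi}(s,a) &= \mathbb{E}_{s' \sim P,\, a' \sim \pi}\left[R(s,a,s') + \gamma\Big(Q^{\pi}(s',a') + \alpha H\left(\pi(\cdot|s')\right)\Big)\right] \\ &= \mathbb{E}_{s' \sim P,\, a' \sim \pi}\left[R(s,a,s') + \gamma\Big(Q^{\pi}(s',a') - \alpha \log \pi(a'|s')\Big)\right] \end{aligned}$$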
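The RHS is an expectation over next states (which come from the replay buffer) and next actions (which come from the current policy, and not the replay buffer). Since it’s an expectation, we can approximate it with samples:
$$Q^{\pi}(s,a) \approx r + \gamma\Big(Q^{\pi}(s',\tilde{a}') - \alpha \log \pi(\tilde{a}'|s')\Big), \quad \tilde{a}' \sim \pi(\cdot|s').$$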
You Should Know
We switch next action notation to $\tilde{a}'$, instead of $a'$, to highlight that the next actions have to be sampled fresh from the policy (whereas by contrast, $r$ and $s'$ should come from the replay buffer).
SAC sets up the MSBE loss for each Q-function using this kind of sample approximation for the target. The only thing still undetermined here is which Q-function gets used to compute the sample backup: like TD3, SAC uses the clipped double-Q trick, and takes the minimum Q-value between the two Q approximators.
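Putting it all together, the loss functions for the Q-networks in SAC are:
$$L(\phi_i, \mathcal{D}) = \mathbb{E}_{(s,a,r,s',d) \sim \mathcal{D}}\left[\Big(Q_{\phi_i}(s,a) - y(r,s',d)\Big)^2\right],$$
where the target is given by
$$y(r,s',d) = r + \gamma (1-d)\left(\min_{j=1,2} Q_{\phi_{\text{targ},j}}(s',\tilde{a}') - \alpha \log \pi_{\theta}(\tilde{a}'|s')\right), \quad \tilde{a}' \sim \pi_{\theta}(\cdot|s').$$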
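The target and loss above can be sketched in plain Python as follows. The minibatch values, the helper names `sac_target` and `msbe_loss`, and the current Q predictions are all hypothetical; in a real implementation these quantities would be tensors produced by the networks and the replay buffer, with `d` equal to 1 for terminal transitions.

```python
def sac_target(r, d, q1_targ, q2_targ, logp_next, gamma=0.99, alpha=0.2):
    """Sample backup: r + gamma * (1 - d) * (min target Q - alpha * log pi(a'|s'))."""
    # Clipped double-Q: take the smaller of the two target Q-values.
    min_q = min(q1_targ, q2_targ)
    return r + gamma * (1.0 - d) * (min_q - alpha * logp_next)

def msbe_loss(q_values, targets):
    """Mean squared Bellman error between Q predictions and fixed targets."""
    return sum((q - y) ** 2 for q, y in zip(q_values, targets)) / len(q_values)

# Hypothetical minibatch of two transitions. q1_targ, q2_targ and logp_next are
# evaluated at a next action freshly sampled from the current policy.
batch = [
    dict(r=1.0, d=0.0, q1_targ=5.0, q2_targ=4.0, logp_next=-1.0),
    dict(r=2.0, d=1.0, q1_targ=3.0, q2_targ=6.0, logp_next=-0.5),
]
targets = [sac_target(**transition) for transition in batch]

# Both Q-functions regress to the same shared targets.
loss_q1 = msbe_loss([5.0, 2.5], targets)  # hypothetical predictions from Q1
loss_q2 = msbe_loss([4.5, 1.5], targets)  # hypothetical predictions from Q2
```

Note that the terminal transition in the second entry contributes only its reward to the target, because the `(1 - d)` factor zeroes out the bootstrapped term.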
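Learning the Policy. The policy should, in each state, act to maximize the expected future return plus expected future entropy. That is, it should maximize $V^{\pi}(s)$, which we expand out into
$$\begin{aligned} V^{\pi}(s) &= \mathbb{E}_{a \sim \pi}\left[Q^{\pi}(s,a)\right] + \alpha H\left(\pi(\cdot|s)\right) \\ &= \mathbb{E}_{a \sim \pi}\left[Q^{\pi}(s,a) - \alpha \log \pi(a|s)\right]. \end{aligned}$$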
The way we optimize the policy makes use of the reparameterization trick, in which a sample from $\pi_{\theta}(\cdot|s)$ is drawn by computing a deterministic function of state, policy parameters, and independent noise. To illustrate: following the authors of the SAC paper, we use a squashed Gaussian policy, which means that samples are obtained according to
$$\tilde{a}_{\theta}(s, \xi) = \tanh\left(\mu_{\theta}(s) + \sigma_{\theta}(s) \odot \xi\right), \quad \xi \sim \mathcal{N}(0, I).$$
You Should Know
This policy has two key differences from the policies we use in the other policy optimization algorithms:
1. The squashing function. The $\tanh$ in the SAC policy ensures that actions are bounded to a finite range. This is absent in the VPG, TRPO, and PPO policies. It also changes the distribution: before the $\tanh$ the SAC policy is a factored Gaussian like the other algorithms’ policies, but after the $\tanh$ it is not. (You can still compute the log-probabilities of actions in closed form, though: see the paper appendix for details.)
2. The way standard deviations are parameterized. In VPG, TRPO, and PPO, we represent the log std devs with state-independent parameter vectors. In SAC, we represent the log std devs as outputs from the neural network, meaning that they depend on state in a complex way. SAC with state-independent log std devs, in our experience, did not work. (Can you think of why? Or better yet: run an experiment to verify?)
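The squashed Gaussian sampling described above, together with its closed-form log-probability, can be sketched as below. This is a minimal plain-Python illustration: `mu` and `log_std` stand in for the network outputs for one state, and the function names are assumptions. The log-probability uses the change-of-variables correction for $\tanh$, written in a numerically stable form.

```python
import math
import random

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sample_squashed_gaussian(mu, log_std, rng, deterministic=False):
    """Return (action, log pi(action|s)) for a tanh-squashed factored Gaussian."""
    action, logp = [], 0.0
    for m, ls in zip(mu, log_std):
        # Reparameterization: independent noise xi, then a deterministic transform.
        xi = 0.0 if deterministic else rng.gauss(0.0, 1.0)
        u = m + math.exp(ls) * xi
        # Log-density of the Gaussian at u (note (u - m) / std == xi).
        logp += -0.5 * xi * xi - ls - 0.5 * math.log(2.0 * math.pi)
        # Change of variables for a = tanh(u): subtract log(1 - tanh(u)^2),
        # computed as 2 * (log 2 - u - softplus(-2u)) for stability.
        logp -= 2.0 * (math.log(2.0) - u - softplus(-2.0 * u))
        action.append(math.tanh(u))
    return action, logp

rng = random.Random(0)
mu, log_std = [0.3, -1.2], [-0.5, 0.1]  # hypothetical network outputs for one state
action, logp = sample_squashed_gaussian(mu, log_std, rng)
mean_action, _ = sample_squashed_gaussian(mu, log_std, rng, deterministic=True)
```

Passing `deterministic=True` skips the noise and returns the squashed mean, and every action component is bounded by the $\tanh$.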
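The reparameterization trick allows us to rewrite the expectation over actions (which contains a pain point: the distribution depends on the policy parameters) into an expectation over noise (which removes the pain point: the distribution now has no dependence on parameters):
$$\mathbb{E}_{a \sim \pi_{\theta}}\left[Q^{\pi_{\theta}}(s,a) - \alpha \log \pi_{\theta}(a|s)\right] = \mathbb{E}_{\xi \sim \mathcal{N}}\left[Q^{\pi_{\theta}}(s,\tilde{a}_{\theta}(s,\xi)) - \alpha \log \pi_{\theta}(\tilde{a}_{\theta}(s,\xi)|s)\right]$$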
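To see why moving the parameter dependence out of the distribution helps, here is a one-dimensional sketch that is deliberately much simpler than the SAC objective. It assumes a hypothetical objective $f(a) = -(a-2)^2$ and an unsquashed Gaussian, and estimates the gradient of $\mathbb{E}[f(a)]$ with respect to the mean by differentiating through the sample $a = \mu + \sigma \xi$.

```python
import random

def f_prime(a):
    """Derivative of the hypothetical objective f(a) = -(a - 2)^2."""
    return -2.0 * (a - 2.0)

def reparam_grad_mu(mu, sigma, n_samples, rng):
    """Monte Carlo estimate of d/dmu E[f(mu + sigma * xi)], with xi ~ N(0, 1)."""
    total = 0.0
    for _ in range(n_samples):
        xi = rng.gauss(0.0, 1.0)  # noise distribution does not depend on mu
        a = mu + sigma * xi       # deterministic function of mu and noise, da/dmu = 1
        total += f_prime(a)       # chain rule: df/dmu = f'(a) * da/dmu
    return total / n_samples

rng = random.Random(0)
grad_estimate = reparam_grad_mu(mu=0.5, sigma=1.0, n_samples=20000, rng=rng)
exact_grad = -2.0 * (0.5 - 2.0)  # closed form for this quadratic objective
```

Because the noise distribution is fixed, the gradient passes straight through the sample; automatic differentiation libraries do the same thing when the objective is a neural network.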
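To get the policy loss, the final step is that we need to substitute $Q^{\pi_{\theta}}$ with one of our function approximators. Unlike in TD3, which uses $Q_{\phi_1}$ (just the first Q approximator), SAC uses $\min_{j=1,2} Q_{\phi_j}$ (the minimum of the two Q approximators). The policy is thus optimized according to
$$\max_{\theta} \mathbb{E}_{s \sim \mathcal{D},\, \xi \sim \mathcal{N}}\left[\min_{j=1,2} Q_{\phi_j}(s,\tilde{a}_{\theta}(s,\xi)) - \alpha \log \pi_{\theta}(\tilde{a}_{\theta}(s,\xi)|s)\right],$$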
which is almost the same as the DDPG and TD3 policy optimization, except for the min-double-Q trick, the stochasticity, and the entropy term.
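As a sample-based sketch of this objective, the function below computes the quantity that a gradient-descent optimizer would minimize, namely the negative of the bracketed term averaged over a minibatch. The function name `policy_loss` and the input values are hypothetical; in practice the inputs would be differentiable tensors evaluated at actions sampled via the reparameterization trick.

```python
def policy_loss(q1_pi, q2_pi, logp_pi, alpha=0.2):
    """Minibatch mean of alpha * log pi(a|s) - min(Q1(s,a), Q2(s,a))."""
    # Minimizing this is the same as maximizing min-Q minus alpha * log-prob.
    terms = [alpha * lp - min(q1, q2)
             for q1, q2, lp in zip(q1_pi, q2_pi, logp_pi)]
    return sum(terms) / len(terms)

# Hypothetical values for a minibatch of two states.
loss_pi = policy_loss(q1_pi=[1.0, 2.0], q2_pi=[0.5, 3.0], logp_pi=[-1.0, -2.0])
```

Lower log-probability (higher entropy) and higher Q-values both reduce the loss, which reflects the trade-off the policy is trained on.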
Exploration vs. Exploitation
SAC trains a stochastic policy with entropy regularization, and explores in an on-policy way. The entropy regularization coefficient $\alpha$ explicitly controls the explore-exploit tradeoff, with higher $\alpha$ corresponding to more exploration, and lower $\alpha$ corresponding to more exploitation. The right coefficient (the one which leads to the stablest / highest-reward learning) may vary from environment to environment, and could require careful tuning.
At test time, to see how well the policy exploits what it has learned, we remove stochasticity and use the mean action instead of a sample from the distribution. This tends to improve performance over the original stochastic policy.
You Should Know
Our SAC implementation uses a trick to improve exploration at the start of training. For a fixed number of steps at the beginning (set with the start_steps keyword argument), the agent takes actions which are sampled from a uniform random distribution over valid actions. After that, it returns to normal SAC exploration.
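The start-of-training trick can be sketched roughly as follows. This assumes, for illustration only, that valid actions form a box with the same symmetric limit in every dimension; `select_action`, `policy_sample`, and `act_limit` are hypothetical names rather than the Spinning Up API, which draws the random actions from the environment's own action space.

```python
import random

def select_action(t, start_steps, policy_sample, act_dim, act_limit, rng):
    """Uniform-random actions for the first start_steps steps, then the policy."""
    if t < start_steps:
        # Assumption: valid actions lie in [-act_limit, act_limit] per dimension.
        return [rng.uniform(-act_limit, act_limit) for _ in range(act_dim)]
    return policy_sample()

rng = random.Random(0)
policy_sample = lambda: [0.0, 0.0]  # stand-in for sampling from the learned policy

early = select_action(0, 10000, policy_sample, act_dim=2, act_limit=1.0, rng=rng)
late = select_action(10000, 10000, policy_sample, act_dim=2, act_limit=1.0, rng=rng)
```

Early in training the policy is close to its initialization, so uniform sampling tends to cover the action space more broadly than the policy's own samples would.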
Documentation
You Should Know
In what follows, we give documentation for the PyTorch and Tensorflow implementations of SAC in Spinning Up. They have nearly identical function calls and docstrings, except for details relating to model construction. However, we include both full docstrings for completeness.
Documentation: PyTorch Version
spinup.sac_pytorch(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs={}, seed=0, steps_per_epoch=4000, epochs=100, replay_size=1000000, gamma=0.99, polyak=0.995, lr=0.001, alpha=0.2, batch_size=100, start_steps=10000, update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000, logger_kwargs={}, save_freq=1)
Soft Actor-Critic (SAC)
Parameters:
- env_fn – A function which creates a copy of the environment. The environment must satisfy the OpenAI Gym API.
- actor_critic – The constructor method for a PyTorch Module with an act method, a pi module, a q1 module, and a q2 module. The act method and pi module should accept batches of observations as inputs, and q1 and q2 should accept a batch of observations and a batch of actions as inputs. When called, act, q1, and q2 should return:
Call | Output Shape | Description |
---|---|---|
act | (batch, act_dim) | Numpy array of actions for each observation. |
q1 | (batch,) | Tensor containing one current estimate of Q* for the provided observations and actions. (Critical: make sure to flatten this!) |
q2 | (batch,) | Tensor containing the other current estimate of Q* for the provided observations and actions. (Critical: make sure to flatten this!) |
Calling pi should return:
Symbol | Shape | Description |
---|---|---|
a | (batch, act_dim) | Tensor containing actions from policy given observations. |
logp_pi | (batch,) | Tensor containing log probabilities of actions in a. Importantly: gradients should be able to flow back into a. |
- ac_kwargs (dict) – Any kwargs appropriate for the ActorCritic object you provided to SAC.
- seed (int) – Seed for random number generators.
- steps_per_epoch (int) – Number of steps of interaction (state-action pairs) for the agent and the environment in each epoch.
- epochs (int) – Number of epochs to run and train agent.
- replay_size (int) – Maximum length of replay buffer.
- gamma (float) – Discount factor. (Always between 0 and 1.)
- polyak (float) – Interpolation factor in polyak averaging for target networks. Target networks are updated towards main networks according to:
$$\theta_{\text{targ}} \leftarrow \rho\, \theta_{\text{targ}} + (1-\rho)\, \theta,$$
where $\rho$ is polyak. (Always between 0 and 1, usually close to 1.)
- lr (float) – Learning rate (used for both policy and value learning).
- alpha (float) – Entropy regularization coefficient. (Equivalent to inverse of reward scale in the original SAC paper.)
- batch_size (int) – Minibatch size for SGD.
- start_steps (int) – Number of steps for uniform-random action selection, before running real policy. Helps exploration.
- update_after (int) – Number of env interactions to collect before starting to do gradient descent updates. Ensures replay buffer is full enough for useful updates.
- update_every (int) – Number of env interactions that should elapse between gradient descent updates. Note: Regardless of how long you wait between updates, the ratio of env steps to gradient steps is locked to 1.
- num_test_episodes (int) – Number of episodes to test the deterministic policy at the end of each epoch.
- max_ep_len (int) – Maximum length of trajectory / episode / rollout.
- logger_kwargs (dict) – Keyword args for EpochLogger.
- save_freq (int) – How often (in terms of gap between epochs) to save the current policy and value function.
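The interaction between update_after and update_every can be illustrated with a small counting sketch. It assumes the update schedule works by running update_every gradient steps each time update_every environment steps have elapsed; the function `count_gradient_steps` is a hypothetical helper that only counts steps and performs no learning.

```python
def count_gradient_steps(total_env_steps, update_after=1000, update_every=50):
    """Count gradient steps taken over a run of total_env_steps interactions."""
    grad_steps = 0
    for t in range(total_env_steps):
        # Updates begin once enough data is collected, and happen in bursts.
        if t >= update_after and t % update_every == 0:
            grad_steps += update_every  # one gradient step per elapsed env step
    return grad_steps

steps_bursty = count_gradient_steps(4000, update_after=1000, update_every=50)
steps_every_step = count_gradient_steps(4000, update_after=1000, update_every=1)
```

Under this assumption both settings perform the same total number of gradient steps after the warm-up period, so changing update_every alters how updates are batched in time rather than how many are taken.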
Saved Model Contents: PyTorch Version
The PyTorch saved model can be loaded with ac = torch.load('path/to/model.pt'), yielding an actor-critic object (ac) that has the properties described in the docstring for sac_pytorch.
You can get actions from this model with
actions = ac.act(torch.as_tensor(obs, dtype=torch.float32))
Documentation: Tensorflow Version
spinup.sac_tf1(env_fn, actor_critic=<function mlp_actor_critic>, ac_kwargs={}, seed=0, steps_per_epoch=4000, epochs=100, replay_size=1000000, gamma=0.99, polyak=0.995, lr=0.001, alpha=0.2, batch_size=100, start_steps=10000, update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000, logger_kwargs={}, save_freq=1)
Soft Actor-Critic (SAC)
Parameters:
- env_fn – A function which creates a copy of the environment. The environment must satisfy the OpenAI Gym API.
- actor_critic – A function which takes in placeholder symbols for state, x_ph, and action, a_ph, and returns the main outputs from the agent’s Tensorflow computation graph:
Symbol | Shape | Description |
---|---|---|
mu | (batch, act_dim) | Computes mean actions from policy given states. |
pi | (batch, act_dim) | Samples actions from policy given states. |
logp_pi | (batch,) | Gives log probability, according to the policy, of the action sampled by pi. Critical: must be differentiable with respect to policy parameters all the way through action sampling. |
q1 | (batch,) | Gives one estimate of Q* for states in x_ph and actions in a_ph. |
q2 | (batch,) | Gives another estimate of Q* for states in x_ph and actions in a_ph. |
- ac_kwargs (dict) – Any kwargs appropriate for the actor_critic function you provided to SAC.
- seed (int) – Seed for random number generators.
- steps_per_epoch (int) – Number of steps of interaction (state-action pairs) for the agent and the environment in each epoch.
- epochs (int) – Number of epochs to run and train agent.
- replay_size (int) – Maximum length of replay buffer.
- gamma (float) – Discount factor. (Always between 0 and 1.)
- polyak (float) – Interpolation factor in polyak averaging for target networks. Target networks are updated towards main networks according to:
$$\theta_{\text{targ}} \leftarrow \rho\, \theta_{\text{targ}} + (1-\rho)\, \theta,$$
where $\rho$ is polyak. (Always between 0 and 1, usually close to 1.)
- lr (float) – Learning rate (used for both policy and value learning).
- alpha (float) – Entropy regularization coefficient. (Equivalent to inverse of reward scale in the original SAC paper.)
- batch_size (int) – Minibatch size for SGD.
- start_steps (int) – Number of steps for uniform-random action selection, before running real policy. Helps exploration.
- update_after (int) – Number of env interactions to collect before starting to do gradient descent updates. Ensures replay buffer is full enough for useful updates.
- update_every (int) – Number of env interactions that should elapse between gradient descent updates. Note: Regardless of how long you wait between updates, the ratio of env steps to gradient steps is locked to 1.
- num_test_episodes (int) – Number of episodes to test the deterministic policy at the end of each epoch.
- max_ep_len (int) – Maximum length of trajectory / episode / rollout.
- logger_kwargs (dict) – Keyword args for EpochLogger.
- save_freq (int) – How often (in terms of gap between epochs) to save the current policy and value function.
Saved Model Contents: Tensorflow Version
The computation graph saved by the logger includes:
Key | Value |
---|---|
x | Tensorflow placeholder for state input. |
a | Tensorflow placeholder for action input. |
mu | Deterministically computes mean action from the agent, given states in x. |
pi | Samples an action from the agent, conditioned on states in x. |
q1 | Gives one action-value estimate for states in x and actions in a. |
q2 | Gives the other action-value estimate for states in x and actions in a. |
v | Gives the value estimate for states in x. |
This saved model can be accessed either by
- running the trained policy with the test_policy.py tool,
- or loading the whole saved graph into a program with restore_tf_graph.
Note: for SAC, the correct evaluation policy is given by mu and not by pi. The policy pi may be thought of as the exploration policy, while mu is the exploitation policy.
References
Relevant Papers
- Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al, 2018
- Soft Actor-Critic Algorithms and Applications, Haarnoja et al, 2018
- Learning to Walk via Deep Reinforcement Learning, Haarnoja et al, 2018
Other Public Implementations
- SAC release repo (original “official” codebase)
- Softlearning repo (current “official” codebase)
- Yarats and Kostrikov repo