Extra Material: TensorFlow Policy Gradient Implementation Examples

Implementing the Simplest Policy Gradient

We give a short TensorFlow implementation of this simple version of the policy gradient algorithm in spinup/examples/tf1/pg_math/1_simple_pg.py. (It can also be viewed on GitHub.) It is only 122 lines long, so we highly recommend reading through it in depth. While we won't go through the entirety of the code here, we'll highlight and explain a few important pieces.

1. Making the Policy Network.

# make core of policy network
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
logits = mlp(obs_ph, sizes=hidden_sizes+[n_acts])

# make action selection op (outputs int actions, sampled from policy)
actions = tf.squeeze(tf.multinomial(logits=logits,num_samples=1), axis=1)

This block builds a feedforward neural network categorical policy. (See the Stochastic Policies section in Part 1 for a refresher.) The logits tensor can be used to construct log-probabilities and probabilities for actions, and the actions tensor samples actions based on the probabilities implied by logits.
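To make the relationship between logits, probabilities, and sampled actions concrete, here is a NumPy-only sketch of what the categorical policy head computes for a batch of logits. It does not use TensorFlow. The helper names log_softmax and sample_actions are our own, and the hard-coded logits simply stand in for the output of mlp on two observations with n_acts = 3.

```python
import numpy as np

def log_softmax(logits):
    # subtract each row's max before exponentiating, for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def sample_actions(logits, rng):
    # one integer action per row, drawn from softmax(logits);
    # plays the role of tf.squeeze(tf.multinomial(logits, 1), axis=1)
    probs = np.exp(log_softmax(logits))
    return np.array([rng.choice(len(p), p=p) for p in probs])

rng = np.random.default_rng(0)
logits = np.array([[2.0, 0.0, -1.0],    # stands in for mlp(obs_ph, ...) on 2 observations
                   [0.0, 0.0, 0.0]])    # with n_acts = 3
probs = np.exp(log_softmax(logits))     # each row sums to 1
acts = sample_actions(logits, rng)      # shape (2,), integers in {0, 1, 2}

# empirical action frequencies for the first observation approach probs[0]
many = np.array([sample_actions(logits, rng) for _ in range(5000)])
freq0 = np.bincount(many[:, 0], minlength=3) / len(many)
```

Equal logits (the second row) give a uniform policy; shifting all logits in a row by the same constant leaves the probabilities unchanged, which is why subtracting the row max is safe.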

2. Making the Loss Function.

# make loss function whose gradient, for the right data, is policy gradient
weights_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
action_masks = tf.one_hot(act_ph, n_acts)
log_probs = tf.reduce_sum(action_masks * tf.nn.log_softmax(logits), axis=1)
loss = -tf.reduce_mean(weights_ph * log_probs)

In this block, we build a “loss” function for the policy gradient algorithm. When the right data is plugged in, the gradient of this loss is equal to the negative of the (sample estimate of the) policy gradient, so a gradient descent step on the loss is a gradient ascent step on expected return. The right data means a set of (state, action, weight) tuples collected while acting according to the current policy, where the weight for a state-action pair is the return from the episode to which it belongs. (As we will show in later subsections, there are other values you can plug in for the weight that also work correctly.)
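The sign relationship between this loss and the policy gradient can be checked numerically. The sketch below is NumPy only and uses our own helper names: it mirrors the TensorFlow loss (one-hot mask, masked log-softmax, negated weighted mean) and compares its finite-difference gradient with respect to the logits against the sample policy gradient estimate, using the standard categorical identity that the gradient of log π(a|s) with respect to the logits is one_hot(a) - softmax(logits). The logits, actions, and weights are made-up data. Because the loss depends on the network parameters only through the logits, the same relationship carries over to the parameters by the chain rule.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def pg_loss(logits, acts, weights):
    # NumPy mirror of the TF block: one-hot mask, masked log-softmax sum, negated weighted mean
    action_masks = np.eye(logits.shape[1])[acts]
    log_probs = (action_masks * log_softmax(logits)).sum(axis=1)
    return -np.mean(weights * log_probs)

def pg_estimate_wrt_logits(logits, acts, weights):
    # Sample policy gradient w.r.t. the logits:
    # (1/N) * sum_i w_i * (one_hot(a_i) - softmax(logits_i))
    probs = np.exp(log_softmax(logits))
    score = np.eye(logits.shape[1])[acts] - probs
    return weights[:, None] * score / len(acts)

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))
acts = np.array([0, 2, 1, 2])
weights = np.array([10.0, 10.0, -3.0, -3.0])  # e.g. R(tau), repeated for each step of an episode

# central finite differences of the loss w.r.t. each logit
eps = 1e-6
num_grad = np.zeros_like(logits)
for idx in np.ndindex(*logits.shape):
    d = np.zeros_like(logits)
    d[idx] = eps
    num_grad[idx] = (pg_loss(logits + d, acts, weights)
                     - pg_loss(logits - d, acts, weights)) / (2 * eps)

pg = pg_estimate_wrt_logits(logits, acts, weights)
print(np.allclose(num_grad, -pg, atol=1e-6))  # True: loss gradient = -(policy gradient estimate)
```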

You Should Know

Even though we describe this as a loss function, it is not a loss function in the typical sense from supervised learning. There are two main differences from standard loss functions.

1. The data distribution depends on the parameters. A loss function is usually defined on a fixed data distribution which is independent of the parameters we aim to optimize. Not so here, where the data must be sampled from the most recent policy.

2. It doesn’t measure performance. A loss function usually evaluates the performance metric that we care about. Here, we care about expected return, J(\pi_{\theta}), but our “loss” function does not approximate this at all, even in expectation. This “loss” function is only useful to us because, when evaluated at the current parameters, with data generated by the current parameters, it has the negative gradient of performance.

But after that first step of gradient descent, there is no more connection to performance. This means that minimizing this “loss” function, for a given batch of data, has no guarantee whatsoever of improving expected return. You can send this loss to -\infty and policy performance could crater; in fact, it usually will. Sometimes a deep RL researcher might describe this outcome as the policy “overfitting” to a batch of data. This is descriptive, but should not be taken literally because it does not refer to generalization error.

We raise this point because it is common for ML practitioners to interpret a loss function as a useful signal during training: “if the loss goes down, all is well.” In policy gradients, this intuition is wrong, and you should only care about average return. The loss function means nothing.
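Here is a minimal illustration of “overfitting” to a batch, under entirely hypothetical numbers. A single-state problem has two actions whose true mean returns we pretend to know (1.0 and 0.5), and one fixed, unlucky batch in which the better action happened to receive a negative return. Repeatedly minimizing the loss on that one batch drives it toward -∞ while the true expected return falls. The true_returns array and expected_return helper exist only for this demonstration; in a real problem you would measure average return from the environment instead.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max()
    return z - np.log(np.exp(z).sum())

# Hypothetical TRUE mean returns of the two actions: action 0 is actually better.
true_returns = np.array([1.0, 0.5])

def expected_return(theta):
    return np.exp(log_softmax(theta)) @ true_returns

# One fixed, unlucky batch: action 0 happened to get a negative return.
acts = np.array([0, 1, 1])
weights = np.array([-1.0, 2.0, 2.0])

theta = np.zeros(2)          # logits of a state-independent categorical policy
lr = 0.5
losses, perf = [], []
for step in range(2000):     # many descent steps on the SAME batch
    logp = log_softmax(theta)
    losses.append(-np.mean(weights * logp[acts]))
    perf.append(expected_return(theta))
    # d loss / d theta = -mean_i w_i * (one_hot(a_i) - softmax(theta))
    grad = -np.mean(weights[:, None] * (np.eye(2)[acts] - np.exp(logp)), axis=0)
    theta -= lr * grad

print(round(losses[0], 3), round(perf[0], 3))   # 0.693 0.75
print(losses[-1] < -100, round(perf[-1], 3))    # True 0.5
```

The loss falls without bound because the policy can keep shrinking the probability of the negatively weighted action; nothing in that trajectory reflects the drop in expected return from 0.75 to 0.5.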

You Should Know

The approach used here to make the log_probs tensor (creating an action mask and using it to select out particular log probabilities) only works for categorical policies. It does not work in general: with a continuous action space there is no finite set of actions to mask.
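As a contrast, the sketch below puts the mask trick next to a diagonal Gaussian log-likelihood, where the density is evaluated directly instead of masked. It is NumPy only and all names are our own. It also shows that, for the categorical case, direct indexing gives the same result as the mask without building the one-hot matrix. The Gaussian formula is the standard one for independent per-dimension normals: log π(a|s) = -0.5 Σ_i [((a_i - μ_i)/σ_i)² + 2 log σ_i + log 2π].

```python
import numpy as np

def categorical_log_prob(logits, acts):
    """log pi(a|s) for a categorical policy via the one-hot mask trick."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_pi = z - np.log(np.exp(z).sum(axis=1, keepdims=True))   # log_softmax
    action_masks = np.eye(logits.shape[1])[acts]                 # like tf.one_hot
    return (action_masks * log_pi).sum(axis=1)

def diag_gaussian_log_prob(mu, log_std, acts):
    """log pi(a|s) for a diagonal Gaussian policy over k-dimensional real actions."""
    std = np.exp(log_std)
    return -0.5 * np.sum(((acts - mu) / std) ** 2 + 2 * log_std + np.log(2 * np.pi), axis=1)

# Categorical: mask-and-sum vs. direct indexing
logits = np.array([[1.0, 2.0, 0.5],
                   [0.0, -1.0, 3.0]])
acts = np.array([1, 2])
lp_mask = categorical_log_prob(logits, acts)
log_pi = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
lp_index = log_pi[np.arange(len(acts)), acts]

# Gaussian: batch of 2 states, 2-D continuous actions, per-dimension std of 1.0 and 0.5
mu = np.zeros((2, 2))
log_std = np.array([0.0, np.log(0.5)])
cont_acts = np.array([[0.0, 0.0],
                      [1.0, -0.5]])
lp_gauss = diag_gaussian_log_prob(mu, log_std, cont_acts)
```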

3. Running One Epoch of Training.

    # for training policy
    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []          # for observations
        batch_acts = []         # for actions
        batch_weights = []      # for R(tau) weighting in policy gradient
        batch_rets = []         # for measuring episode returns
        batch_lens = []         # for measuring episode lengths

        # reset episode-specific variables
        obs = env.reset()       # first obs comes from starting distribution
        done = False            # signal from environment that episode is over
        ep_rews = []            # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        while True:

            # rendering
            if not(finished_rendering_this_epoch):
                env.render()

            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            act = sess.run(actions, {obs_ph: obs.reshape(1,-1)})[0]
            obs, rew, done, _ = env.step(act)

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                # if episode is over, record info about episode
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a|s) is R(tau)
                batch_weights += [ep_ret] * ep_len

                # reset episode-specific variables
                obs, done, ep_rews = env.reset(), False, []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        batch_loss, _ = sess.run([loss, train_op],
                                 feed_dict={
                                    obs_ph: np.array(batch_obs),
                                    act_ph: np.array(batch_acts),
                                    weights_ph: np.array(batch_weights)
                                 })
        return batch_loss, batch_rets, batch_lens

The train_one_epoch() function runs one “epoch” of policy gradient, which we define to be

  1. the experience collection step (lines 62-97 of 1_simple_pg.py), where the agent acts in the environment using the most recent policy, finishing episodes until it has collected more than batch_size steps, followed by
  2. a single policy gradient update step (lines 99-105).

The main loop of the algorithm just repeatedly calls train_one_epoch().
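For readers without a TensorFlow 1 setup, the same epoch structure can be sketched in plain NumPy. This is an illustrative reconstruction under stated assumptions, not the code in 1_simple_pg.py: ToyEnv is a hypothetical five-step environment following the old Gym reset/step API, the policy is linear (logits = obs @ W) rather than an MLP, rendering is omitted, and the gradient of the loss is written out by hand instead of being produced by a TensorFlow optimizer.

```python
import numpy as np

class ToyEnv:
    """HYPOTHETICAL stand-in for a Gym env (old API: reset() -> obs, step() -> obs, rew, done, info).
    Each episode lasts 5 steps; reward is 1.0 when act == int(obs[0] > 0), else 0.0."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)

    def reset(self):
        self.t = 0
        self.obs = self.rng.normal(size=2)
        return self.obs

    def step(self, act):
        rew = float(act == int(self.obs[0] > 0))
        self.t += 1
        self.obs = self.rng.normal(size=2)
        return self.obs, rew, self.t >= 5, {}


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def train_one_epoch(env, W, rng, batch_size=500, lr=0.5):
    """Collect > batch_size steps with the current policy, then take ONE gradient step on
    loss = -mean(weights * log pi(a|s)). Linear policy (assumption): logits = obs @ W.
    W is updated in place."""
    batch_obs, batch_acts, batch_weights = [], [], []
    batch_rets, batch_lens = [], []
    obs, done, ep_rews = env.reset(), False, []

    # experience collection with the current, fixed policy
    while True:
        batch_obs.append(obs.copy())
        probs = np.exp(log_softmax(obs @ W))
        act = rng.choice(len(probs), p=probs)
        obs, rew, done, _ = env.step(act)
        batch_acts.append(act)
        ep_rews.append(rew)
        if done:
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)
            batch_weights += [ep_ret] * ep_len   # every step weighted by R(tau)
            obs, done, ep_rews = env.reset(), False, []
            if len(batch_obs) > batch_size:
                break

    # single policy gradient update (hand-written gradient in place of autodiff)
    O, A, w = np.array(batch_obs), np.array(batch_acts), np.array(batch_weights)
    logp = log_softmax(O @ W)
    batch_loss = -np.mean(w * logp[np.arange(len(A)), A])
    score = np.eye(W.shape[1])[A] - np.exp(logp)    # d log pi(a|s) / d logits
    grad = -(O.T @ (w[:, None] * score)) / len(A)   # chain rule through logits = O @ W
    W -= lr * grad
    return batch_loss, batch_rets, batch_lens


# main loop: just call train_one_epoch repeatedly and track average return
env, rng = ToyEnv(seed=0), np.random.default_rng(1)
W = np.zeros((2, 2))
history = []
for epoch in range(40):
    batch_loss, batch_rets, batch_lens = train_one_epoch(env, W, rng)
    history.append(np.mean(batch_rets))
```

As in the original, batch_loss is returned for logging only; per the note above, progress should be judged from the average return in history, not from the loss.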

Implementing Reward-to-Go Policy Gradient

We give a short TensorFlow implementation of the reward-to-go policy gradient in spinup/examples/tf1/pg_math/2_rtg_pg.py. (It can also be viewed on GitHub.)

The only thing that has changed from 1_simple_pg.py is that we now use different weights in the loss function. The code modification is very slight: we add a new function, and change two other lines. The new function is:

def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

Then we change the old lines 86-87 of 1_simple_pg.py from:

                # the weight for each logprob(a|s) is R(tau)
                batch_weights += [ep_ret] * ep_len

to (lines 93-94 of 2_rtg_pg.py):

                # the weight for each logprob(a_t|s_t) is reward-to-go from t
                batch_weights += list(reward_to_go(ep_rews))
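To see exactly what changes in the weights, here is a small worked example on a hypothetical four-step episode, comparing the R(tau) weights of 1_simple_pg.py with the reward-to-go weights of 2_rtg_pg.py. It reuses reward_to_go as given above and adds a vectorized equivalent (a reversed cumulative sum) as a cross-check; the vectorized form is our own alternative, not what the example file uses.

```python
import numpy as np

def reward_to_go(rews):
    # same function as in 2_rtg_pg.py
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

ep_rews = [1.0, 0.0, 2.0, 3.0]   # hypothetical rewards from one episode

# 1_simple_pg.py weighting: every step gets the full episode return R(tau)
weights_simple = [sum(ep_rews)] * len(ep_rews)   # [6.0, 6.0, 6.0, 6.0]

# 2_rtg_pg.py weighting: step t gets the sum of rewards from t to the end of the episode
weights_rtg = list(reward_to_go(ep_rews))        # [6.0, 5.0, 5.0, 3.0]

# vectorized equivalent: cumulative sum of the reversed rewards, reversed back
weights_vec = np.cumsum(ep_rews[::-1])[::-1]
```

Only the first step keeps the full return; later steps drop the rewards that were earned before them, which their actions could not have influenced.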