Skip to content

Commit b217c8a

Browse files
committed
intermediate/rpc_tutorial: fix REINFORCE return calculation across observers
1 parent 03706c5 commit b217c8a

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

intermediate_source/rpc_tutorial.rst

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,18 @@ Hence, we skip describing its contents.
253253
class Agent:
254254
...
255255
def finish_episode(self):
256-
# joins probs and rewards from different observers into lists
257-
R, probs, rewards = 0, [], []
256+
# joins probs and per-observer discounted returns into flat lists;
257+
# returns are computed per observer so trajectories from different
258+
# observers don't bleed into each other through the reverse cumulative sum
259+
probs, returns = [], []
258260
for ob_id in self.rewards:
261+
R = 0
262+
ob_returns = []
263+
for r in self.rewards[ob_id][::-1]:
264+
R = r + args.gamma * R
265+
ob_returns.insert(0, R)
266+
returns.extend(ob_returns)
259267
probs.extend(self.saved_log_probs[ob_id])
260-
rewards.extend(self.rewards[ob_id])
261268
262269
# use the minimum observer reward to calculate the running reward
263270
min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
@@ -268,10 +275,7 @@ Hence, we skip describing its contents.
268275
self.rewards[ob_id] = []
269276
self.saved_log_probs[ob_id] = []
270277
271-
policy_loss, returns = [], []
272-
for r in rewards[::-1]:
273-
R = r + args.gamma * R
274-
returns.insert(0, R)
278+
policy_loss = []
275279
returns = torch.tensor(returns)
276280
returns = (returns - returns.mean()) / (returns.std() + self.eps)
277281
for log_prob, R in zip(probs, returns):

0 commit comments

Comments
 (0)