@@ -253,11 +253,18 @@ Hence, we skip describing its contents.
253253 class Agent :
254254 ...
255255 def finish_episode (self ):
256- # joins probs and rewards from different observers into lists
257- R, probs, rewards = 0 , [], []
256+ # joins probs and per-observer discounted returns into flat lists;
257+ # returns are computed per observer so trajectories from different
258+ # observers don't bleed into each other through the reverse cumulative sum
259+ probs, returns = [], []
258260 for ob_id in self .rewards:
261+ R = 0
262+ ob_returns = []
263+ for r in self .rewards[ob_id][::- 1 ]:
264+ R = r + args.gamma * R
265+ ob_returns.insert(0 , R)
266+ returns.extend(ob_returns)
259267 probs.extend(self .saved_log_probs[ob_id])
260- rewards.extend(self .rewards[ob_id])
261268
262269 # use the minimum observer reward to calculate the running reward
263270 min_reward = min ([sum (self .rewards[ob_id]) for ob_id in self .rewards])
@@ -268,10 +275,7 @@ Hence, we skip describing its contents.
268275 self .rewards[ob_id] = []
269276 self .saved_log_probs[ob_id] = []
270277
271- policy_loss, returns = [], []
272- for r in rewards[::- 1 ]:
273- R = r + args.gamma * R
274- returns.insert(0 , R)
278+ policy_loss = []
275279 returns = torch.tensor(returns)
276280 returns = (returns - returns.mean()) / (returns.std() + self .eps)
277281 for log_prob, R in zip (probs, returns):
0 commit comments