@@ -95,24 +95,28 @@ def update(self, batch: list[dict]) -> dict[str, float]:
 
         for item in batch:
             posterior_state = self.observe(self.init_state(batch_size=1), item["obs"])
-            pred_state, _pred_obs, pred_reward, pred_done, _aux = self.predict(
-                posterior_state, int(item["action"])
-            )
-            del pred_state
+            action_tensor = self._action_tensor(int(item["action"]))
+            prior_stats = self.prior(torch.cat([posterior_state["h"], action_tensor], dim=-1))
+            mean, logvar = torch.chunk(prior_stats, 2, dim=-1)
+            std = torch.exp(0.5 * logvar).clamp(min=1e-4)
+            eps = torch.randn_like(std)
+            z = mean + eps * std
+            h = self.gru(z, posterior_state["h"])
+
+            pred_reward = self.reward_head(h).squeeze(-1)
+            pred_done_prob = torch.sigmoid(self.done_head(h)).squeeze(-1)
 
             target_reward = torch.tensor([item["reward"]], device=self.device)
             target_done = torch.tensor([float(item["done"])], device=self.device)
 
-            reward_loss = (
-                (torch.tensor([pred_reward], device=self.device) - target_reward).pow(2).mean()
-            )
-            done_loss = (
-                (torch.tensor([float(pred_done)], device=self.device) - target_done).pow(2).mean()
-            )
+            reward_loss = (pred_reward - target_reward).pow(2).mean()
+            done_loss = (pred_done_prob - target_done).pow(2).mean()
 
-            mean = posterior_state["mean"]
-            logvar = posterior_state["logvar"]
-            kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+            posterior_mean = posterior_state["mean"]
+            posterior_logvar = posterior_state["logvar"]
+            kl = -0.5 * torch.mean(
+                1 + posterior_logvar - posterior_mean.pow(2) - posterior_logvar.exp()
+            )
 
             total = total + reward_loss + done_loss + 0.1 * kl
             kl_total = kl_total + kl
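For reviewers, a minimal standalone sketch of what the new prior rollout computes, runnable in isolation. The nn.Linear/nn.GRUCell layout and the h_dim/z_dim/action_dim sizes are assumptions for illustration only; the actual modules (self.prior, self.gru, self.reward_head, self.done_head) are defined elsewhere in the class and may differ.

import torch
import torch.nn as nn

# Assumed sizes, for illustration only.
h_dim, z_dim, action_dim = 64, 16, 4

prior = nn.Linear(h_dim + action_dim, 2 * z_dim)  # stands in for self.prior; outputs [mean, logvar]
gru = nn.GRUCell(z_dim, h_dim)                    # stands in for self.gru
reward_head = nn.Linear(h_dim, 1)                 # stands in for self.reward_head
done_head = nn.Linear(h_dim, 1)                   # stands in for self.done_head

h = torch.zeros(1, h_dim)                   # plays the role of posterior_state["h"]
action_tensor = torch.zeros(1, action_dim)  # plays the role of self._action_tensor(action)

# Prior over the latent, conditioned on (h, action).
mean, logvar = torch.chunk(prior(torch.cat([h, action_tensor], dim=-1)), 2, dim=-1)
std = torch.exp(0.5 * logvar).clamp(min=1e-4)  # clamp keeps sampling numerically safe
z = mean + torch.randn_like(std) * std         # reparameterization trick

# Deterministic state update, then scalar heads on the new state.
h_next = gru(z, h)
pred_reward = reward_head(h_next).squeeze(-1)                  # shape (1,)
pred_done_prob = torch.sigmoid(done_head(h_next)).squeeze(-1)  # shape (1,)

# Closed-form KL(N(mean, std^2) || N(0, 1)), averaged over dimensions; the
# diff applies the same identity to the posterior's mean/logvar.
kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())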