diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index a80e5d9ec..019b40ad5 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -144,13 +144,20 @@ def step_fn(state, key) -> SelfplayOutput: actor = state.current_player keys = jax.random.split(key2, batch_size) state = jax.vmap(auto_reset(env.step, env.init))(state, policy_output.action, keys) + reward = state.rewards[jnp.arange(state.rewards.shape[0]), actor] + terminated = state.terminated discount = -1.0 * jnp.ones_like(value) - discount = jnp.where(state.terminated, 0.0, discount) - return state, SelfplayOutput( + discount = jnp.where(terminated, 0.0, discount) + next_state = state.replace( + rewards=jnp.zeros_like(state.rewards), + terminated=jnp.zeros_like(state.terminated), + truncated=jnp.zeros_like(state.truncated), + ) + return next_state, SelfplayOutput( obs=observation, action_weights=policy_output.action_weights, - reward=state.rewards[jnp.arange(state.rewards.shape[0]), actor], - terminated=state.terminated, + reward=reward, + terminated=terminated, discount=discount, )