From 131eb2bef073521d587d29cf31702447d4278f2c Mon Sep 17 00:00:00 2001 From: Yusuke-Mukuta Date: Wed, 3 Jun 2026 16:14:07 +0000 Subject: [PATCH] Fix AlphaZero auto-reset carry state --- examples/alphazero/train.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/alphazero/train.py b/examples/alphazero/train.py index a80e5d9ec..019b40ad5 100644 --- a/examples/alphazero/train.py +++ b/examples/alphazero/train.py @@ -144,13 +144,20 @@ def step_fn(state, key) -> SelfplayOutput: actor = state.current_player keys = jax.random.split(key2, batch_size) state = jax.vmap(auto_reset(env.step, env.init))(state, policy_output.action, keys) + reward = state.rewards[jnp.arange(state.rewards.shape[0]), actor] + terminated = state.terminated discount = -1.0 * jnp.ones_like(value) - discount = jnp.where(state.terminated, 0.0, discount) - return state, SelfplayOutput( + discount = jnp.where(terminated, 0.0, discount) + next_state = state.replace( + rewards=jnp.zeros_like(state.rewards), + terminated=jnp.zeros_like(state.terminated), + truncated=jnp.zeros_like(state.truncated), + ) + return next_state, SelfplayOutput( obs=observation, action_weights=policy_output.action_weights, - reward=state.rewards[jnp.arange(state.rewards.shape[0]), actor], - terminated=state.terminated, + reward=reward, + terminated=terminated, discount=discount, )