diff --git a/pgx/experimental/utils.py b/pgx/experimental/utils.py index 79d5f18e3..e3d4c110d 100644 --- a/pgx/experimental/utils.py +++ b/pgx/experimental/utils.py @@ -13,4 +13,4 @@ def act_randomly(rng: PRNGKey, legal_action_mask: Array) -> Array: "Note that codes under pgx.experimental are subject to change without notice." ) logits = jnp.log(legal_action_mask.astype(jnp.float32)) - return jax.random.categorical(rng, logits=logits, axis=1) + return jax.random.categorical(rng, logits=logits, axis=-1)