From 0590b7834b9d73bd486cb593254c91ea0a6d3b4a Mon Sep 17 00:00:00 2001 From: Jesse Rosenstock Date: Mon, 25 May 2026 21:10:55 +0200 Subject: [PATCH] Fix jax.random.categorical axis parameter for JAX 0.10.1 compatibility `jax.random.categorical` now strictly validates the `axis` parameter against the actual rank of the logits array. Passing `axis=1` on a 1D legal_action_mask raised `IndexError: index 1 is out of bounds for axis 0 with size 1`, breaking `test_api` across all games. Fix by using `axis=-1`, which selects the last axis and works correctly for both unbatched (1D) and batched (2D) inputs. See: https://docs.jax.dev/en/latest/_autosummary/jax.random.categorical.html --- pgx/experimental/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgx/experimental/utils.py b/pgx/experimental/utils.py index 79d5f18e3..e3d4c110d 100644 --- a/pgx/experimental/utils.py +++ b/pgx/experimental/utils.py @@ -13,4 +13,4 @@ def act_randomly(rng: PRNGKey, legal_action_mask: Array) -> Array: "Note that codes under pgx.experimental are subject to change without notice." ) logits = jnp.log(legal_action_mask.astype(jnp.float32)) - return jax.random.categorical(rng, logits=logits, axis=1) + return jax.random.categorical(rng, logits=logits, axis=-1)