Skip to content

Fix jax.random.categorical axis parameter for JAX 0.10.1 compatibility#1314

Open
jmr wants to merge 1 commit into
sotetsuk:mainfrom
jmr:jax-random
Open

Fix jax.random.categorical axis parameter for JAX 0.10.1 compatibility#1314
jmr wants to merge 1 commit into
sotetsuk:mainfrom
jmr:jax-random

Conversation

@jmr

@jmr jmr commented May 25, 2026

Copy link
Copy Markdown

jax.random.categorical now strictly validates the axis parameter against the actual rank of the logits array. Passing axis=1 on a 1D legal_action_mask raised IndexError: index 1 is out of bounds for axis 0 with size 1, breaking test_api across all games.

Fix by using axis=-1, which selects the last axis and works correctly for both unbatched (1D) and batched (2D) inputs.

See: https://docs.jax.dev/en/latest/_autosummary/jax.random.categorical.html

`jax.random.categorical` now strictly validates the `axis` parameter
against the actual rank of the logits array. Passing `axis=1` on a
1D legal_action_mask raised `IndexError: index 1 is out of bounds for
axis 0 with size 1`, breaking `test_api` across all games.

Fix by using `axis=-1`, which selects the last axis and works correctly
for both unbatched (1D) and batched (2D) inputs.

See: https://docs.jax.dev/en/latest/_autosummary/jax.random.categorical.html
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant