You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)
This version uses jnp.take_along_axis which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule which requires
indices.shape == input.shape + (1,).
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
stack-info: PR: #2061, branch: AmesingFlank/stack/27
0 commit comments