Skip to content

[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)#2061

Draft
AmesingFlank wants to merge 1 commit into
AmesingFlank/stack/26from
AmesingFlank/stack/27
Draft

[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)#2061
AmesingFlank wants to merge 1 commit into
AmesingFlank/stack/26from
AmesingFlank/stack/27

Conversation

@AmesingFlank
Copy link
Copy Markdown
Contributor

@AmesingFlank AmesingFlank commented Apr 20, 2026

Stacked PRs:


[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)

This version uses jnp.take_along_axis which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule which requires
indices.shape == input.shape + (1,).

Co-Authored-By: Claude Sonnet 4 noreply@anthropic.com

AmesingFlank added a commit that referenced this pull request Apr 20, 2026
This version uses jnp.take_along_axis which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule which requires
indices.shape == input.shape + (1,).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

stack-info: PR: #2061, branch: AmesingFlank/stack/27
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/26 branch from 1a3b7f5 to 696b52e Compare April 20, 2026 21:39
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 3a1edd6 to 15cbc51 Compare April 20, 2026 21:39
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 20, 2026
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 21:42
AmesingFlank added a commit that referenced this pull request Apr 20, 2026
This version uses jnp.take_along_axis which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule which requires
indices.shape == input.shape + (1,).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

stack-info: PR: #2061, branch: AmesingFlank/stack/27
AmesingFlank added a commit that referenced this pull request Apr 20, 2026
This version uses jnp.take_along_axis which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule which requires
indices.shape == input.shape + (1,).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

stack-info: PR: #2061, branch: AmesingFlank/stack/27
This version uses jnp.take_along_axis which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule which requires
indices.shape == input.shape + (1,).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

stack-info: PR: #2061, branch: AmesingFlank/stack/27
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main April 20, 2026 22:00
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 15cbc51 to b1ca465 Compare April 20, 2026 22:00
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 April 20, 2026 22:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant