[Pallas] Implement indirect gather via exact one-hot matmul #2035

Closed
thcmbs wants to merge 1 commit into pytorch:main from thcmbs:pallas-copy3d-gather

thcmbs (Collaborator) commented Apr 17, 2026

Pallas lacks native gather instructions. This PR implements indirect gather (tensor[idx]) by lowering to a one-hot matrix multiplication.

  • One-hot Matmul: Gather is implemented as one_hot(idx, V) @ table.
  • Exact Precision: The matmul is forced to float32 with Precision.HIGHEST so that the MXU does not truncate values to lower precision, and each gathered row matches the table row exactly.
  • VMEM Protection: The entire table must reside in VMEM. Added a hard 16 MiB threshold. Tables exceeding this limit will fail fast with an explicit error instead of a generic Pallas OOM.
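
The three points above can be sketched in plain JAX, outside of any Pallas kernel. This is a minimal illustration under stated assumptions, not the code from this PR: the function name `one_hot_gather` and the constant `VMEM_TABLE_LIMIT_BYTES` are hypothetical, the table is assumed to be 2-D with shape (V, D), and the size check only mimics the fail-fast behavior described in the last bullet.

```python
import jax
import jax.numpy as jnp

# Hypothetical constant mirroring the 16 MiB threshold from the description.
VMEM_TABLE_LIMIT_BYTES = 16 * 1024 * 1024


def one_hot_gather(table: jax.Array, idx: jax.Array) -> jax.Array:
    """Gather rows table[idx] via one_hot(idx, V) @ table (illustrative sketch).

    Assumes `table` has shape (V, D) and `idx` is a 1-D integer array.
    """
    # Fail fast with an explicit message when the table is too large,
    # instead of relying on a generic out-of-memory error later.
    table_bytes = table.size * table.dtype.itemsize
    if table_bytes > VMEM_TABLE_LIMIT_BYTES:
        raise ValueError(
            f"table needs {table_bytes} bytes, which exceeds the "
            f"{VMEM_TABLE_LIMIT_BYTES}-byte limit for one-hot gather"
        )

    num_rows = table.shape[0]
    # Each row of `one_hot` has a single 1.0 at position idx[i], shape (N, V).
    one_hot = jax.nn.one_hot(idx, num_rows, dtype=jnp.float32)
    # float32 operands plus Precision.HIGHEST request a full-precision matmul,
    # so each output row is 1.0 * table[idx[i]] plus exact zeros.
    out = jnp.dot(
        one_hot,
        table.astype(jnp.float32),
        precision=jax.lax.Precision.HIGHEST,
    )
    return out.astype(table.dtype)


# Usage: the result should match ordinary indexing.
table = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
idx = jnp.array([3, 0, 2, 2])
gathered = one_hot_gather(table, idx)
```

The cost of this formulation is a dense (N, V) one-hot operand and O(N·V·D) multiply-adds, which is why the whole table has to fit in fast memory and why a size guard is useful.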

meta-cla (bot) added the CLA Signed label Apr 17, 2026
thcmbs force-pushed the pallas-copy3d-gather branch 5 times, most recently from b50d87c to 1e43767, April 17, 2026 07:06
thcmbs marked this pull request as ready for review April 17, 2026 07:10
thcmbs marked this pull request as draft April 17, 2026 08:56
thcmbs force-pushed the pallas-copy3d-gather branch from 91e3aec to eb00094, April 17, 2026 09:41
thcmbs closed this May 18, 2026