Commit 55f875b
committed
[Pallas] Lower aten gather using one_hot + sum for TPU compatibility, unblocking cross_entropy
TPU Mosaic has very limited lax.gather support, so jnp.take_along_axis
fails during lowering. Instead, implement gather(input, dim, index) as:
mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)
result = sum(input * mask, axis=dim, keepdims=True)
Also removes the xfailIfPallas mark from test_cross_entropy since the
gather lowering now works.
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
stack-info: PR: #2060, branch: AmesingFlank/stack/261 parent 1ad88fd commit 55f875b
2 files changed
Lines changed: 170 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2511 | 2511 | | |
2512 | 2512 | | |
2513 | 2513 | | |
| 2514 | + | |
| 2515 | + | |
| 2516 | + | |
| 2517 | + | |
| 2518 | + | |
| 2519 | + | |
| 2520 | + | |
| 2521 | + | |
| 2522 | + | |
| 2523 | + | |
| 2524 | + | |
| 2525 | + | |
| 2526 | + | |
| 2527 | + | |
| 2528 | + | |
| 2529 | + | |
| 2530 | + | |
| 2531 | + | |
| 2532 | + | |
| 2533 | + | |
| 2534 | + | |
| 2535 | + | |
| 2536 | + | |
| 2537 | + | |
| 2538 | + | |
| 2539 | + | |
| 2540 | + | |
| 2541 | + | |
| 2542 | + | |
| 2543 | + | |
| 2544 | + | |
| 2545 | + | |
| 2546 | + | |
| 2547 | + | |
| 2548 | + | |
| 2549 | + | |
| 2550 | + | |
| 2551 | + | |
| 2552 | + | |
| 2553 | + | |
| 2554 | + | |
| 2555 | + | |
| 2556 | + | |
| 2557 | + | |
| 2558 | + | |
| 2559 | + | |
| 2560 | + | |
| 2561 | + | |
| 2562 | + | |
| 2563 | + | |
| 2564 | + | |
| 2565 | + | |
| 2566 | + | |
| 2567 | + | |
| 2568 | + | |
| 2569 | + | |
| 2570 | + | |
| 2571 | + | |
| 2572 | + | |
| 2573 | + | |
| 2574 | + | |
| 2575 | + | |
| 2576 | + | |
| 2577 | + | |
| 2578 | + | |
| 2579 | + | |
| 2580 | + | |
| 2581 | + | |
| 2582 | + | |
| 2583 | + | |
| 2584 | + | |
| 2585 | + | |
| 2586 | + | |
| 2587 | + | |
| 2588 | + | |
2514 | 2589 | | |
2515 | 2590 | | |
2516 | 2591 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3466 | 3466 | | |
3467 | 3467 | | |
3468 | 3468 | | |
| 3469 | + | |
| 3470 | + | |
| 3471 | + | |
| 3472 | + | |
| 3473 | + | |
| 3474 | + | |
| 3475 | + | |
| 3476 | + | |
| 3477 | + | |
| 3478 | + | |
| 3479 | + | |
| 3480 | + | |
| 3481 | + | |
| 3482 | + | |
| 3483 | + | |
| 3484 | + | |
| 3485 | + | |
| 3486 | + | |
| 3487 | + | |
| 3488 | + | |
| 3489 | + | |
| 3490 | + | |
| 3491 | + | |
| 3492 | + | |
| 3493 | + | |
| 3494 | + | |
| 3495 | + | |
| 3496 | + | |
| 3497 | + | |
| 3498 | + | |
| 3499 | + | |
| 3500 | + | |
| 3501 | + | |
| 3502 | + | |
| 3503 | + | |
| 3504 | + | |
| 3505 | + | |
| 3506 | + | |
| 3507 | + | |
| 3508 | + | |
| 3509 | + | |
| 3510 | + | |
| 3511 | + | |
| 3512 | + | |
| 3513 | + | |
| 3514 | + | |
| 3515 | + | |
| 3516 | + | |
| 3517 | + | |
| 3518 | + | |
| 3519 | + | |
| 3520 | + | |
| 3521 | + | |
| 3522 | + | |
| 3523 | + | |
| 3524 | + | |
| 3525 | + | |
| 3526 | + | |
| 3527 | + | |
| 3528 | + | |
| 3529 | + | |
| 3530 | + | |
| 3531 | + | |
| 3532 | + | |
| 3533 | + | |
| 3534 | + | |
| 3535 | + | |
| 3536 | + | |
| 3537 | + | |
| 3538 | + | |
| 3539 | + | |
| 3540 | + | |
| 3541 | + | |
| 3542 | + | |
| 3543 | + | |
| 3544 | + | |
| 3545 | + | |
| 3546 | + | |
| 3547 | + | |
| 3548 | + | |
| 3549 | + | |
| 3550 | + | |
| 3551 | + | |
| 3552 | + | |
| 3553 | + | |
| 3554 | + | |
| 3555 | + | |
| 3556 | + | |
| 3557 | + | |
| 3558 | + | |
| 3559 | + | |
| 3560 | + | |
| 3561 | + | |
| 3562 | + | |
| 3563 | + | |
3469 | 3564 | | |
3470 | 3565 | | |
3471 | 3566 | | |
| |||
0 commit comments