|
20 | 20 |
|
21 | 21 | import torch |
22 | 22 | from executorch.backends.mlx.builder.op_helpers import ( |
| 23 | + emit_if_else, |
23 | 24 | emit_lifted_constant, |
24 | 25 | emit_quantized_biases, |
| 26 | + emit_shape, |
25 | 27 | parse_dequant_node, |
26 | 28 | to_mlx_qparams, |
27 | 29 | torch_dtype_to_scalar_type, |
|
115 | 117 | PartitionNode, |
116 | 118 | PowerNode, |
117 | 119 | ProdNode, |
| 120 | + RandomBitsNode, |
118 | 121 | ReciprocalNode, |
119 | 122 | RemainderNode, |
120 | 123 | RepeatNode, |
@@ -3513,6 +3516,203 @@ def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: |
3513 | 3516 | return out |
3514 | 3517 |
|
3515 | 3518 |
|
@REGISTRY.register(target=[torch.ops.mlx.sample.default])
def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower mlx.sample to an on-device token draw.

    Positional args (3 or 4, no kwargs): logits, temperature, top_p and an
    optional seed tensor.

    Two branches are emitted under a single if/else on ``temperature > 0``:

    * sampling: argmax(top_p_mask(logits / temperature) + gumbel_noise), where
      the noise is built as RandomBitsNode -> uniform -> -log(-log(u)) from
      elementwise nodes, so only the chosen index is produced rather than the
      full logits tensor.
    * greedy: plain argmax(logits) when the condition is false (temperature
      == 0 being the intended case), which makes 0 exact instead of a
      small-epsilon approximation.

    Returns the slot bound to ``n``; both branches write their argmax into it.
    """
    args = P.args(n)
    require_args(args, 3, 4, "mlx.sample")
    require_kwargs(P.kwargs(n), set(), "mlx.sample")
    logits = args[0]
    temperature = args[1]
    top_p = args[2]
    # A missing fourth arg and an explicit None both mean "no seed".
    seed = args[3] if len(args) > 3 else None

    temp_dt = n.args[1].meta["val"].dtype
    out = P.make_or_get_slot(n)

    # --- small emit helpers: each allocates its temp slot, then emits ------

    def fresh():
        # New temporary tensor slot (the first tuple element is not needed).
        _, slot = P.make_tmp_slot()
        return slot

    def unary(node_cls, src, **options):
        # Node with fields (x, out, *options); returns the result slot.
        dst = fresh()
        P.emit(node_cls(x=P.slot_to_tid(src), out=P.slot_to_tid(dst), **options))
        return dst

    def binary(node_cls, lhs, rhs):
        # Node with fields (a, b, out); returns the result slot.
        dst = fresh()
        P.emit(
            node_cls(
                a=P.slot_to_tid(lhs),
                b=P.slot_to_tid(rhs),
                out=P.slot_to_tid(dst),
            )
        )
        return dst

    def as_float32(src):
        dst = fresh()
        P.emit(
            AsTypeNode(
                x=P.slot_to_tid(src),
                out=P.slot_to_tid(dst),
                scalar_type=torch_dtype_to_scalar_type(torch.float32),
            )
        )
        return dst

    def select(cond, if_true, if_false):
        dst = fresh()
        P.emit(
            WhereNode(
                condition=P.slot_to_tid(cond),
                x=P.slot_to_tid(if_true),
                y=P.slot_to_tid(if_false),
                out=P.slot_to_tid(dst),
            )
        )
        return dst

    def argmax_into_out(src):
        # Both branches finish by writing argmax over the last axis into `out`.
        P.emit(
            ArgmaxNode(
                x=P.slot_to_tid(src),
                out=P.slot_to_tid(out),
                axis=-1,
                keepdims=False,
            )
        )

    # --- the two branches ---------------------------------------------------

    def emit_greedy():
        argmax_into_out(logits)

    def emit_sample():
        shape = emit_shape(P, n.args[0], logits)

        # Runtime seed, when given: tensor -> SymInt (Vid) through ItemIntNode.
        # Without one, RandomBitsNode.seed stays unset (MLX global RNG).
        seed_field = None
        if seed is not None:
            _, seed_value = P.make_tmp_value_slot()
            P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_value)))
            seed_field = P.slot_to_vid(seed_value)

        # Uniform u in [0, 1): bits / uint32_max, then clamped to the largest
        # float32 below 1 (random.cpp:95).
        raw_bits = fresh()
        P.emit(
            RandomBitsNode(
                out=P.slot_to_tid(raw_bits), shape=shape, width=4, seed=seed_field
            )
        )
        bits_as_float = as_float32(raw_bits)
        uint32_max = emit_lifted_constant(P, 4294967295.0, torch.float32)
        unit = binary(DivideNode, bits_as_float, uint32_max)
        below_one = emit_lifted_constant(
            P,
            float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))),
            torch.float32,
        )
        uniform = binary(MinimumNode, unit, below_one)

        # Gumbel noise g = -log(-log(u)). The chain is kept in fp32: in bf16
        # ties get mis-ranked and the clamp value becomes 1.0, giving +inf.
        log_u = unary(LogNode, uniform)
        neg_log_u = unary(NegNode, log_u)
        log_neg_log_u = unary(LogNode, neg_log_u)
        gumbel = unary(NegNode, log_neg_log_u)

        # Temperature-scaled logits over the vocab axis, in float32.
        logits_f32 = as_float32(logits)
        scaled = binary(DivideNode, logits_f32, temperature)

        # Top-p nucleus mask. SortNode only sorts ascending, so negate, sort,
        # negate again to obtain the probabilities in descending order.
        probs = unary(SoftmaxNode, scaled, axis=-1)
        negated = unary(NegNode, probs)
        negated_sorted = unary(SortNode, negated, axis=-1)
        descending = unary(NegNode, negated_sorted)
        running = unary(CumsumNode, descending, axis=-1)
        # Mass strictly before each sorted position (0 for the top-1 token).
        mass_before = binary(SubtractNode, running, descending)
        # A sorted token is outside the nucleus once the mass before it
        # already exceeds top_p; the top-1 token therefore always stays.
        outside = binary(GreaterNode, mass_before, top_p)
        pos_inf = emit_lifted_constant(P, float("inf"), torch.float32)
        nucleus = select(outside, pos_inf, descending)
        # Per-row cutoff = smallest probability still inside the nucleus.
        cutoff = unary(MinNode, nucleus, axes=[-1], keepdims=True)
        rejected = binary(LessNode, probs, cutoff)
        neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32)
        filtered = select(rejected, neg_inf, scaled)

        perturbed = binary(AddNode, filtered, gumbel)
        argmax_into_out(perturbed)

    # Branch on temperature > 0: true -> sampling chain, false -> greedy argmax.
    zero = emit_lifted_constant(P, 0.0, temp_dt)
    is_sampling = binary(GreaterNode, temperature, zero)
    _, cond_value = P.make_tmp_value_slot()
    P.emit(ItemIntNode(x=P.slot_to_tid(is_sampling), out=P.slot_to_vid(cond_value)))
    emit_if_else(P, P.to_int_or_vid(cond_value), emit_sample, emit_greedy)
    return out
| 3714 | + |
| 3715 | + |
3516 | 3716 | @REGISTRY.register(target=[torch.ops.aten.argmin.default]) |
3517 | 3717 | def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot: |
3518 | 3718 | """Handle aten.argmin - index of min element along axis.""" |
|
0 commit comments