Skip to content

Commit 42ee32d

Browse files
MLX: on-device token sampling with Gumbel-max (#20454)
### Summary Adds token sampling that runs inside the exported .pte for the MLX backend: a model wrapped in SamplingHead returns a sampled token id instead of [B, S, vocab] logits, avoiding the per-step logits copy to host and the host-side softmax+multinomial. Sampling uses Gumbel-max: argmax(logits / temperature + g), g = -log(-log(u)). The only new schema primitive is a random source, RandomBitsNode, the rest reuses existing nodes. Greedy = temperature → 0. temperature is a runtime input; seed is optional. Changes - schema.fbs: new RandomBitsNode (append-only union member, optional seed). - custom_kernel_ops/sample.py: mlx::sample op + register_fake + CPU reference. - ops.py: _sample_handler lowering the Gumbel-max graph. - runtime/MLXInterpreter.h: exec_random_bits + dispatch. - llm/sampling.py: SamplingHead wrapper. - generate.py: None-guard optional compound fields so the optional seed (de)serializes. Notes - Uniform/gumbel computed in fp32 (bf16 rounds the ~1.0 clamp up → log(0)=-inf → poisons argmax). - Tests: custom_kernel_ops/test/test_sample.py, eager parity/distribution/determinism, export+partition lowering, and on-device e2e (incl. a bf16 large-vocab regression). Fixes #20353
1 parent d523daa commit 42ee32d

8 files changed

Lines changed: 713 additions & 1 deletion

File tree

.github/workflows/mlx.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ jobs:
8484
backends/mlx/test/test_partitioner.py \
8585
backends/mlx/test/test_serialization_dedup.py \
8686
backends/mlx/test/test_slot_recycling.py \
87+
backends/mlx/test/test_sample.py \
8788
examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \
8889
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
8990
-v

backends/mlx/custom_ops.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,54 @@ def gather_qmm_fake(
391391
else:
392392
batch = w.shape[:-2]
393393
return x.new_empty((*batch, M, N))
394+
395+
396+
@torch.library.custom_op("mlx::sample", mutates_args=())
397+
def sample(
398+
logits: Tensor,
399+
temperature: Tensor,
400+
top_p: Tensor,
401+
seed: Optional[Tensor] = None,
402+
) -> Tensor:
403+
"""
404+
Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus).
405+
logits: [B, vocab]
406+
temperature: scalar float tensor (runtime input). temperature <= 0 is
407+
greedy: return argmax(logits) directly (matches the device,
408+
which branches on temperature > 0).
409+
top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every
410+
token, i.e. it is off.
411+
seed: scalar int tensor or None
412+
- tensor -> deterministic, keyed RNG (random::key(seed))
413+
- None -> MLX global KeySequence (non-deterministic)
414+
-> token_id: [B] int64
415+
416+
Host/CPU reference used for export (shape/meta) and distributional checks
417+
only. It is NOT bit-identical to the lowered on-device graph: this uses torch
418+
RNG (plain torch.rand, no uint32/nextafter uniform) while the delegate uses
419+
MLX RNG, so a given seed does not reproduce the same tokens host vs. device.
420+
"""
421+
if float(temperature) <= 0: # matches the device cond (temperature > 0)
422+
return torch.argmax(logits, dim=-1)
423+
# whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties).
424+
scaled = logits.float() / temperature
425+
probs = torch.softmax(scaled, dim=-1)
426+
s_probs, _ = torch.sort(probs, dim=-1, descending=True)
427+
cum = torch.cumsum(s_probs, dim=-1)
428+
keep = (cum - s_probs) <= top_p
429+
thresh = torch.where(keep, s_probs, s_probs.new_tensor(float("inf"))).amin(
430+
dim=-1, keepdim=True
431+
)
432+
scaled = torch.where(probs >= thresh, scaled, scaled.new_tensor(float("-inf")))
433+
if seed is None:
434+
u = torch.rand(scaled.shape) # global RNG
435+
else:
436+
gen = torch.Generator().manual_seed(int(seed.item()))
437+
u = torch.rand(scaled.shape, generator=gen)
438+
gumbel = -torch.log(-torch.log(u))
439+
return torch.argmax(scaled + gumbel, dim=-1)
440+
441+
442+
@torch.library.register_fake("mlx::sample")
443+
def sample_fake(logits, temperature, top_p, seed=None):
444+
return logits.new_empty(logits.shape[:-1], dtype=torch.long)

backends/mlx/llm/sampling.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
#
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
13+
class SamplingHead(nn.Module):
14+
"""
15+
Wraps a model that returns logits and samples a token id on-device.
16+
17+
forward(*model_args, temperature, top_k=None, top_p=1.0, seed=None,
18+
**model_kwargs) -> token_id
19+
20+
temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be >= 0;
21+
temperature=0 is greedy (returns argmax, no division).
22+
top_k: not implemented yet (reserved); must be None.
23+
top_p: scalar float tensor in (0, 1] for nucleus sampling. top_p=1.0
24+
(the default) keeps every token, i.e. no filtering. Pass it
25+
as a runtime input to tune per request.
26+
seed: scalar int tensor (seeded) or None (unseeded export)
27+
"""
28+
29+
def __init__(self, model: nn.Module):
30+
super().__init__()
31+
self.model = model
32+
33+
def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):
34+
if top_k is not None:
35+
raise NotImplementedError("top_k sampling is not implemented")
36+
logits = self.model(*args, **kwargs) # [B, S, vocab]
37+
last = logits[:, -1, :] # [B, vocab]
38+
if not isinstance(top_p, torch.Tensor):
39+
top_p = torch.tensor(float(top_p))
40+
return torch.ops.mlx.sample(last, temperature, top_p, seed)

backends/mlx/ops.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020

2121
import torch
2222
from executorch.backends.mlx.builder.op_helpers import (
23+
emit_if_else,
2324
emit_lifted_constant,
2425
emit_quantized_biases,
26+
emit_shape,
2527
parse_dequant_node,
2628
to_mlx_qparams,
2729
torch_dtype_to_scalar_type,
@@ -115,6 +117,7 @@
115117
PartitionNode,
116118
PowerNode,
117119
ProdNode,
120+
RandomBitsNode,
118121
ReciprocalNode,
119122
RemainderNode,
120123
RepeatNode,
@@ -3513,6 +3516,203 @@ def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot:
35133516
return out
35143517

35153518

3519+
@REGISTRY.register(target=[torch.ops.mlx.sample.default])
3520+
def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot:
3521+
"""Gumbel-max sampling: argmax(logits / temperature + gumbel_noise).
3522+
3523+
Reproduces MLX's uniform -> gumbel -> argmax layering in the IR using the
3524+
new RandomBitsNode plus existing elementwise nodes, so a sampled token id is
3525+
produced on-device instead of returning the full logits tensor.
3526+
3527+
temperature == 0 is greedy: an IfNode branches to a plain argmax(logits),
3528+
skipping the sampling chain (so 0 is exact, not the small-epsilon approx).
3529+
"""
3530+
args = P.args(n)
3531+
require_args(args, 3, 4, "mlx.sample")
3532+
require_kwargs(P.kwargs(n), set(), "mlx.sample")
3533+
logits, temperature, top_p = args[0], args[1], args[2]
3534+
seed = args[3] if len(args) > 3 and args[3] is not None else None
3535+
3536+
temp_dt = n.args[1].meta["val"].dtype
3537+
out = P.make_or_get_slot(n)
3538+
3539+
def emit_greedy():
3540+
P.emit(
3541+
ArgmaxNode(
3542+
x=P.slot_to_tid(logits),
3543+
out=P.slot_to_tid(out),
3544+
axis=-1,
3545+
keepdims=False,
3546+
)
3547+
)
3548+
3549+
def emit_sample():
3550+
shape = emit_shape(P, n.args[0], logits)
3551+
3552+
# Optional runtime seed: tensor -> SymInt (Vid) via ItemIntNode. Absent ->
3553+
# leave RandomBitsNode.seed unset (MLX global RNG).
3554+
seed_field = None
3555+
if seed is not None:
3556+
_, seed_val = P.make_tmp_value_slot()
3557+
P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val)))
3558+
seed_field = P.slot_to_vid(seed_val)
3559+
3560+
# uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95)
3561+
_, bits = P.make_tmp_slot()
3562+
P.emit(
3563+
RandomBitsNode(
3564+
out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field
3565+
)
3566+
)
3567+
_, bits_f = P.make_tmp_slot()
3568+
P.emit(
3569+
AsTypeNode(
3570+
x=P.slot_to_tid(bits),
3571+
out=P.slot_to_tid(bits_f),
3572+
scalar_type=torch_dtype_to_scalar_type(torch.float32),
3573+
)
3574+
)
3575+
umax = emit_lifted_constant(P, 4294967295.0, torch.float32)
3576+
_, div0 = P.make_tmp_slot()
3577+
P.emit(
3578+
DivideNode(
3579+
a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0)
3580+
)
3581+
)
3582+
prev1 = emit_lifted_constant(
3583+
P,
3584+
float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))),
3585+
torch.float32,
3586+
)
3587+
_, clamp = P.make_tmp_slot()
3588+
P.emit(
3589+
MinimumNode(
3590+
a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp)
3591+
)
3592+
)
3593+
# gumbel g = -log(-log(u)); whole chain stays fp32 (bf16 mis-ranks ties; clamp->1.0->+inf).
3594+
_, l1 = P.make_tmp_slot()
3595+
P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1)))
3596+
_, g1 = P.make_tmp_slot()
3597+
P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1)))
3598+
_, l2 = P.make_tmp_slot()
3599+
P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2)))
3600+
_, g = P.make_tmp_slot()
3601+
P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g)))
3602+
3603+
# sample: argmax(logits / temperature + g) over the vocab axis, in float32
3604+
_, logits_f = P.make_tmp_slot()
3605+
P.emit(
3606+
AsTypeNode(
3607+
x=P.slot_to_tid(logits),
3608+
out=P.slot_to_tid(logits_f),
3609+
scalar_type=torch_dtype_to_scalar_type(torch.float32),
3610+
)
3611+
)
3612+
_, scaled = P.make_tmp_slot()
3613+
P.emit(
3614+
DivideNode(
3615+
a=P.slot_to_tid(logits_f),
3616+
b=P.slot_to_tid(temperature),
3617+
out=P.slot_to_tid(scaled),
3618+
)
3619+
)
3620+
3621+
# top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending.
3622+
_, probs = P.make_tmp_slot()
3623+
P.emit(SoftmaxNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(probs), axis=-1))
3624+
_, neg_p = P.make_tmp_slot()
3625+
P.emit(NegNode(x=P.slot_to_tid(probs), out=P.slot_to_tid(neg_p)))
3626+
_, sorted_neg = P.make_tmp_slot()
3627+
P.emit(SortNode(x=P.slot_to_tid(neg_p), out=P.slot_to_tid(sorted_neg), axis=-1))
3628+
_, sorted_p = P.make_tmp_slot()
3629+
P.emit(NegNode(x=P.slot_to_tid(sorted_neg), out=P.slot_to_tid(sorted_p)))
3630+
_, cum = P.make_tmp_slot()
3631+
P.emit(CumsumNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(cum), axis=-1))
3632+
_, prefix = P.make_tmp_slot()
3633+
P.emit(
3634+
SubtractNode(
3635+
a=P.slot_to_tid(cum),
3636+
b=P.slot_to_tid(sorted_p),
3637+
out=P.slot_to_tid(prefix),
3638+
)
3639+
)
3640+
# remove sorted tokens whose prefix mass already exceeds top_p (top-1: 0)
3641+
_, remove = P.make_tmp_slot()
3642+
P.emit(
3643+
GreaterNode(
3644+
a=P.slot_to_tid(prefix),
3645+
b=P.slot_to_tid(top_p),
3646+
out=P.slot_to_tid(remove),
3647+
)
3648+
)
3649+
pos_inf = emit_lifted_constant(P, float("inf"), torch.float32)
3650+
_, kept = P.make_tmp_slot()
3651+
P.emit(
3652+
WhereNode(
3653+
condition=P.slot_to_tid(remove),
3654+
x=P.slot_to_tid(pos_inf),
3655+
y=P.slot_to_tid(sorted_p),
3656+
out=P.slot_to_tid(kept),
3657+
)
3658+
)
3659+
# threshold = smallest kept probability (per row)
3660+
_, thresh = P.make_tmp_slot()
3661+
P.emit(
3662+
MinNode(
3663+
x=P.slot_to_tid(kept),
3664+
out=P.slot_to_tid(thresh),
3665+
axes=[-1],
3666+
keepdims=True,
3667+
)
3668+
)
3669+
_, drop = P.make_tmp_slot()
3670+
P.emit(
3671+
LessNode(
3672+
a=P.slot_to_tid(probs),
3673+
b=P.slot_to_tid(thresh),
3674+
out=P.slot_to_tid(drop),
3675+
)
3676+
)
3677+
neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32)
3678+
_, masked = P.make_tmp_slot()
3679+
P.emit(
3680+
WhereNode(
3681+
condition=P.slot_to_tid(drop),
3682+
x=P.slot_to_tid(neg_inf),
3683+
y=P.slot_to_tid(scaled),
3684+
out=P.slot_to_tid(masked),
3685+
)
3686+
)
3687+
3688+
_, noisy = P.make_tmp_slot()
3689+
P.emit(
3690+
AddNode(
3691+
a=P.slot_to_tid(masked), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy)
3692+
)
3693+
)
3694+
P.emit(
3695+
ArgmaxNode(
3696+
x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False
3697+
)
3698+
)
3699+
3700+
# temperature == 0 -> greedy: IfNode branches to argmax(logits), skipping sampling.
3701+
zero = emit_lifted_constant(P, 0.0, temp_dt)
3702+
_, is_sampling = P.make_tmp_slot()
3703+
P.emit(
3704+
GreaterNode(
3705+
a=P.slot_to_tid(temperature),
3706+
b=P.slot_to_tid(zero),
3707+
out=P.slot_to_tid(is_sampling),
3708+
)
3709+
)
3710+
_, cond_val = P.make_tmp_value_slot()
3711+
P.emit(ItemIntNode(x=P.slot_to_tid(is_sampling), out=P.slot_to_vid(cond_val)))
3712+
emit_if_else(P, P.to_int_or_vid(cond_val), emit_sample, emit_greedy)
3713+
return out
3714+
3715+
35163716
@REGISTRY.register(target=[torch.ops.aten.argmin.default])
35173717
def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot:
35183718
"""Handle aten.argmin - index of min element along axis."""

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,26 @@ exec_argmax(const ArgmaxNode& n, ExecutionState& st, StreamOrDevice s) {
16951695
st.set_tensor(n.out, argmax(x, n.axis, n.keepdims, s));
16961696
}
16971697

1698+
inline void exec_random_bits(
1699+
const RandomBitsNode& n,
1700+
ExecutionState& st,
1701+
StreamOrDevice s) {
1702+
// random::bits supports width (bytes/element) in {1, 2, 4} ->
1703+
// uint8/uint16/uint32.
1704+
if (n.width != 1 && n.width != 2 && n.width != 4) {
1705+
throw std::runtime_error("random_bits: width must be 1, 2, or 4");
1706+
}
1707+
auto shape = to_shape(n.shape, st);
1708+
// uint32 (4 bytes, the widest supported) is a safe upper bound for the guard.
1709+
check_allocation_bounded(shape, uint32, "random_bits");
1710+
std::optional<array> key = std::nullopt;
1711+
if (n.seed.has_value()) {
1712+
key = random::key(
1713+
static_cast<uint64_t>(st.const_value_ref<int32_t>(n.seed.value())));
1714+
}
1715+
st.set_tensor(n.out, random::bits(shape, n.width, key, s));
1716+
}
1717+
16981718
inline void
16991719
exec_argmin(const ArgminNode& n, ExecutionState& st, StreamOrDevice s) {
17001720
const auto& x = st.const_tensor_ref(n.x);
@@ -2057,6 +2077,9 @@ class Interpreter {
20572077
case OpCode::ARGMAX:
20582078
ops::exec_argmax(std::get<ArgmaxNode>(instr.node), st, s);
20592079
break;
2080+
case OpCode::RANDOM_BITS:
2081+
ops::exec_random_bits(std::get<RandomBitsNode>(instr.node), st, s);
2082+
break;
20602083
case OpCode::SLICE_UPDATE:
20612084
ops::exec_slice_update(std::get<SliceUpdateNode>(instr.node), st, s);
20622085
break;

backends/mlx/serialization/schema.fbs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,14 @@ table IfNode {
985985
else_chain_idx: uint32; // index into MLXGraph.instruction_chains
986986
}
987987

988+
table RandomBitsNode {
989+
out: Tid (required);
990+
shape: [IntOrVid] (required);
991+
seed: Vid; // OPTIONAL: present -> random::key(seed);
992+
// absent -> MLX global KeySequence
993+
width: int32 = 4; // bytes per element (4 -> uint32)
994+
}
995+
988996
// Custom Metal kernel execution via mlx::core::fast::metal_kernel().
989997
// Two-phase API:
990998
// 1. Factory: metal_kernel(name, input_names, output_names, source, header,
@@ -1161,7 +1169,8 @@ union OpNode {
11611169
BitwiseAndNode,
11621170
BitwiseOrNode,
11631171
BitwiseXorNode,
1164-
IfNode
1172+
IfNode,
1173+
RandomBitsNode
11651174
// BC: Add new op nodes here (append only)
11661175
}
11671176

0 commit comments

Comments
 (0)