 """
-GPU-side Gumbel-max sampler with optional top-k / top-p filtering.
+GPU-side Gumbel-max sampler.

 Self-contained sampling utility that can be imported by other models. Lives
 in its own file so it can be reused without pulling in the heavy MoE module.

-All sampling parameters (``temperature``, ``top_k``, ``top_p``) are
-**runtime tensors** so a single exported program can be re-driven with
-different sampling configurations without re-export.
+``temperature`` is a runtime tensor so a single exported program can be
+re-driven with different sampling configurations without re-export.
+
 """

 from typing import Optional
 def sample(
     logits: torch.Tensor,
     temperature: Optional[torch.Tensor] = None,
-    top_k: Optional[torch.Tensor] = None,
-    top_p: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    """GPU-side Gumbel-max sampler with optional top-k / top-p filtering.
-
-    All three sampling knobs are *runtime* scalar tensors so the caller can
-    change them between calls without re-exporting the graph. The Python-
-    level ``is None`` checks are static (decided at trace time) and select
-    which subgraph is emitted; once provided, the actual values are pure
-    tensors and the kernels are fully data-driven.
+    """GPU-side Gumbel-max sampler.

-    When ``temperature``, ``top_k`` and ``top_p`` are all ``None`` (the
-    eager / eval default), the function is a no-op and returns ``logits``
-    unchanged — useful for callers that just want to inspect raw logits.
+    When ``temperature`` is ``None`` (the eager / eval default) the function
+    is a no-op and returns ``logits`` unchanged — useful for callers that
+    just want to inspect raw logits.

     Otherwise it draws from ``softmax(logits / temperature)`` entirely
     on-device using the Gumbel-max trick:
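Aside, not part of the diff: for readers new to the trick, here is a minimal self-contained sketch of the equivalence the sampler relies on. The ``gumbel_max_sample`` helper name is hypothetical and illustrative only.

import torch

# Sketch of the Gumbel-max trick: argmax(logits + g), with i.i.d.
# g ~ Gumbel(0, 1), is distributed exactly like a draw from softmax(logits).
# The noise is derived as -log(-log(u)) for u ~ Uniform(0, 1); the small
# epsilons keep both logs away from log(0).
def gumbel_max_sample(logits: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + eps) + eps)
    return torch.argmax(logits + gumbel, dim=-1, keepdim=True)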
@@ -41,58 +33,23 @@ def sample(
     float32 logits. The contract is documented as ``[B, V]`` float32 and
     callers are expected to ``.float()``-cast before invoking ``sample``.

+    TODO(gasoonjia): add top-k / top-p filtering support in a follow-up PR.
+
     Args:
         logits: ``[B, V]`` float32 logits.
         temperature: 0-D or 1-D float tensor (clamped to >= 1e-6 to avoid
-            divide-by-zero). ``None`` skips temperature scaling.
-        top_k: 0-D or 1-D int tensor — keep only the top ``k`` logits.
-            ``None`` skips top-k filtering. ``k >= V`` is also a no-op.
-        top_p: 0-D or 1-D float tensor — nucleus threshold; keep the
-            smallest set of logits whose cumulative softmax probability
-            is >= ``top_p``. ``None`` (or ``>= 1.0``) disables top-p.
+            divide-by-zero). ``None`` skips temperature scaling and the
+            sampler returns the unmodified ``logits`` tensor.

     Returns:
         ``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified
-        ``logits`` tensor when all sampling parameters are ``None``.
+        ``logits`` tensor when ``temperature`` is ``None``.
     """
     # No sampling configured — return raw logits.
-    if temperature is None and top_k is None and top_p is None:
+    if temperature is None:
         return logits

-    if temperature is not None:
-        logits = logits / temperature.clamp(min=1e-6)
-
-    # Single sort handles both top-k and top-p filtering — both branches
-    # need descending logits anyway, so we share the sort to keep the
-    # graph small.
-    if top_k is not None or top_p is not None:
-        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
-        sorted_remove = torch.zeros_like(sorted_logits, dtype=torch.bool)
-
-        if top_k is not None:
-            # Position >= k → drop. Works for any tensor k via broadcast;
-            # k >= V naturally becomes a no-op (mask is all-False).
-            pos = torch.arange(sorted_logits.size(-1), device=sorted_logits.device)
-            sorted_remove = sorted_remove | (pos >= top_k.to(pos.dtype))
-
-        if top_p is not None:
-            cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
-            p_remove = cum_probs > top_p
-            # Shift right by one so the highest-prob token is always kept,
-            # even when its single-token prob already exceeds top_p.
-            p_remove = torch.cat(
-                [torch.zeros_like(p_remove[..., :1]), p_remove[..., :-1]],
-                dim=-1,
-            )
-            sorted_remove = sorted_remove | p_remove
-
-        sorted_logits = torch.where(
-            sorted_remove,
-            torch.full_like(sorted_logits, float("-inf")),
-            sorted_logits,
-        )
-        # Scatter the masked sorted logits back into original token order.
-        logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)
+    logits = logits / temperature.clamp(min=1e-6)

     # Gumbel-max sampling — equivalent to sampling from softmax(logits)
     # but fully on-device and CUDA-graph friendly. The 1e-20 epsilons are
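Aside, not part of the diff: a minimal usage sketch of the new contract, assuming the file is importable so that ``sample`` is in scope (the import path below is hypothetical). With ``temperature=None`` the call is a no-op that returns the raw logits; with a runtime temperature tensor it returns ``[B, 1]`` sampled token IDs, and the tensor can change between calls without rebuilding or re-exporting anything.

import torch

# Hypothetical import path, not from the diff:
# from sampler import sample

logits = torch.randn(2, 32000)  # [B, V] float32 logits

raw = sample(logits)  # temperature=None: no-op, raw logits come back
assert raw is logits

# ``temperature`` is a runtime tensor, so the same program (eager or
# exported once) can be re-driven with a different value on every call.
ids = sample(logits, temperature=torch.tensor(0.8))
assert ids.shape == (2, 1)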