
Commit 61d47aa

move top-p and top-k support into an individual PR
1 parent eff4294 commit 61d47aa

5 files changed

Lines changed: 49 additions & 255 deletions


examples/models/qwen3_5_moe/export.py

Lines changed: 2 additions & 25 deletions
@@ -770,23 +770,10 @@ def _export_cuda(model, config, args):
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
     decode_pos = torch.tensor([0], dtype=torch.long)
     decode_temperature = torch.tensor([1.0], dtype=torch.float32)
-    # top_k / top_p are runtime scalar tensors (parallel to temperature) so
-    # the same .pte can be re-driven with different sampling configurations
-    # without re-export. Default examples are no-op values: top_k=V (keep
-    # all tokens), top_p=1.0 (keep full nucleus). Callers override them at
-    # runtime by binding different scalar tensors.
-    decode_top_k = torch.tensor(config.vocab_size, dtype=torch.int64)
-    decode_top_p = torch.tensor(1.0, dtype=torch.float32)
     with torch.no_grad():
         decode_ep = export(
             model,
-            (
-                decode_tokens,
-                decode_pos,
-                decode_temperature,
-                decode_top_k,
-                decode_top_p,
-            ),
+            (decode_tokens, decode_pos, decode_temperature),
             strict=True,
         )
     print("Decode export successful!")
@@ -803,26 +790,16 @@ def _export_cuda(model, config, args):
     prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
     prefill_pos = torch.arange(example_prefill_len, dtype=torch.long)
     prefill_temperature = torch.tensor([1.0], dtype=torch.float32)
-    prefill_top_k = torch.tensor(config.vocab_size, dtype=torch.int64)
-    prefill_top_p = torch.tensor(1.0, dtype=torch.float32)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
         {0: seq_dim},  # input_pos
         None,  # temperature (static scalar tensor)
-        None,  # top_k (static scalar tensor — runtime-bindable)
-        None,  # top_p (static scalar tensor — runtime-bindable)
     )
     with torch.no_grad():
         prefill_ep = export(
             model,
-            (
-                prefill_tokens,
-                prefill_pos,
-                prefill_temperature,
-                prefill_top_k,
-                prefill_top_p,
-            ),
+            (prefill_tokens, prefill_pos, prefill_temperature),
             dynamic_shapes=prefill_dynamic_shapes,
             strict=True,
         )
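The deleted comment spelled out the design that still holds for ``temperature``: sampling knobs are passed as runtime tensors rather than baked into the graph, so one export serves every sampling configuration. A minimal illustration (not part of the commit; it reuses ``decode_ep``, ``decode_tokens``, and ``decode_pos`` from the hunk above):

    import torch

    # temperature is a graph input, so re-binding it requires no re-export.
    # Illustrative only: decode_ep / decode_tokens / decode_pos are the
    # objects produced in the export hunk above.
    hot = decode_ep.module()(decode_tokens, decode_pos, torch.tensor([1.2], dtype=torch.float32))
    cold = decode_ep.module()(decode_tokens, decode_pos, torch.tensor([0.2], dtype=torch.float32))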

examples/models/qwen3_5_moe/main.cpp

Lines changed: 0 additions & 28 deletions
@@ -37,14 +37,6 @@ DEFINE_string(
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
-DEFINE_int64(
-    top_k,
-    -1,
-    "Top-k sampling cutoff (<=0 = no-op default of vocab_size, keeps all tokens).");
-DEFINE_double(
-    top_p,
-    1.0,
-    "Top-p (nucleus) sampling threshold. 1.0 = no-op (keeps full nucleus).");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -206,22 +198,6 @@ int main(int argc, char** argv) {
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
 
-  // top_k / top_p are 0-D scalar tensors matching the export-time signature
-  // (see examples/models/qwen3_5_moe/export.py). The default flag values
-  // (top_k = vocab_size, top_p = 1.0) are mathematical no-ops: the sort+
-  // scatter subgraph still runs (it was traced into the graph at export
-  // time), but produces all-False filter masks so logits pass through
-  // unchanged. Override at runtime to enable real filtering.
-  int64_t vocab_size = metadata.count(llm::kVocabSize)
-      ? metadata[llm::kVocabSize]
-      : static_cast<int64_t>(tokenizer->vocab_size());
-  int64_t top_k_val = (FLAGS_top_k <= 0) ? vocab_size : FLAGS_top_k;
-  float top_p_val = static_cast<float>(FLAGS_top_p);
-  auto top_k_tensor =
-      from_blob(&top_k_val, {}, executorch::aten::ScalarType::Long);
-  auto top_p_tensor =
-      from_blob(&top_p_val, {}, executorch::aten::ScalarType::Float);
-
   // ---------------------------------------------------------------
   // Prefill
   // ---------------------------------------------------------------
@@ -252,8 +228,6 @@ int main(int argc, char** argv) {
   prefill_inputs.push_back(tokens_tensor);
   prefill_inputs.push_back(pos_tensor);
   prefill_inputs.push_back(temp_tensor);
-  prefill_inputs.push_back(top_k_tensor);
-  prefill_inputs.push_back(top_p_tensor);
 
   auto prefill_result = module->execute(run_method, prefill_inputs);
   if (prefill_result.error() != Error::Ok) {
@@ -302,8 +276,6 @@ int main(int argc, char** argv) {
   decode_inputs.push_back(EValue(decode_tokens));
   decode_inputs.push_back(EValue(decode_pos));
   decode_inputs.push_back(EValue(temp_tensor));
-  decode_inputs.push_back(EValue(top_k_tensor));
-  decode_inputs.push_back(EValue(top_p_tensor));
 
   auto decode_result = module->execute("decode", decode_inputs);
   if (decode_result.error() != Error::Ok) {

examples/models/qwen3_5_moe/model.py

Lines changed: 6 additions & 7 deletions
@@ -631,8 +631,6 @@ def forward(
         tokens: torch.LongTensor,
         input_pos: torch.LongTensor,
         temperature: Optional[torch.Tensor] = None,
-        top_k: Optional[torch.Tensor] = None,
-        top_p: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         x = self.embed_tokens(tokens)
         for layer in self.layers:
@@ -642,16 +640,17 @@ def forward(
         # logits so callers (eval, custom samplers) can inspect every
         # position. Otherwise apply the prefill optimization and only
         # materialize ``[B, V]`` for the last token.
-        if temperature is None and top_k is None and top_p is None:
-            return self.lm_head(x)
+        if temperature is None:
+            return self.lm_head(x).float()  # [B, T, V] float32
         logits = self.lm_head(x[:, -1, :]).float()  # [B, V] float32
         # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
         # equivalent to drawing from softmax(logits/T) but stays entirely
-        # on-device.
+        # on-device. Algorithm reference:
+        # https://huggingface.co/blog/cxdu/fastsampling
         # TODO(gasoonjia): once the on-device sampling stack lands, promote
         # ``sample`` into a shared CUDA sampling utility reusable by other
-        # models.
-        return sample(logits, temperature, top_k, top_p)  # [B, 1]
+        # models, and add top-k / top-p filtering support.
+        return sample(logits, temperature)  # [B, 1]
 
     @staticmethod
     def from_hf_checkpoint(model_dir, max_seq_len=4096):
examples/models/qwen3_5_moe/sampler.py

Lines changed: 15 additions & 58 deletions
@@ -1,12 +1,12 @@
 """
-GPU-side Gumbel-max sampler with optional top-k / top-p filtering.
+GPU-side Gumbel-max sampler.
 
 Self-contained sampling utility that can be imported by other models. Lives
 in its own file so it can be reused without pulling in the heavy MoE module.
 
-All sampling parameters (``temperature``, ``top_k``, ``top_p``) are
-**runtime tensors** so a single exported program can be re-driven with
-different sampling configurations without re-export.
+``temperature`` is a runtime tensor so a single exported program can be
+re-driven with different sampling configurations without re-export.
+
 """
 
 from typing import Optional
@@ -17,20 +17,12 @@
 def sample(
     logits: torch.Tensor,
     temperature: Optional[torch.Tensor] = None,
-    top_k: Optional[torch.Tensor] = None,
-    top_p: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    """GPU-side Gumbel-max sampler with optional top-k / top-p filtering.
-
-    All three sampling knobs are *runtime* scalar tensors so the caller can
-    change them between calls without re-exporting the graph. The Python-
-    level ``is None`` checks are static (decided at trace time) and select
-    which subgraph is emitted; once provided, the actual values are pure
-    tensors and the kernels are fully data-driven.
+    """GPU-side Gumbel-max sampler.
 
-    When ``temperature``, ``top_k`` and ``top_p`` are all ``None`` (the
-    eager / eval default), the function is a no-op and returns ``logits``
-    unchanged — useful for callers that just want to inspect raw logits.
+    When ``temperature`` is ``None`` (the eager / eval default) the function
+    is a no-op and returns ``logits`` unchanged — useful for callers that
+    just want to inspect raw logits.
 
     Otherwise it draws from ``softmax(logits / temperature)`` entirely
     on-device using the Gumbel-max trick:
@@ -41,58 +33,23 @@ def sample(
     float32 logits. The contract is documented as ``[B, V]`` float32 and
     callers are expected to ``.float()``-cast before invoking ``sample``.
 
+    TODO(gasoonjia): add top-k / top-p filtering support in a follow-up PR.
+
     Args:
         logits: ``[B, V]`` float32 logits.
         temperature: 0-D or 1-D float tensor (clamped to >= 1e-6 to avoid
-            divide-by-zero). ``None`` skips temperature scaling.
-        top_k: 0-D or 1-D int tensor — keep only the top ``k`` logits.
-            ``None`` skips top-k filtering. ``k >= V`` is also a no-op.
-        top_p: 0-D or 1-D float tensor — nucleus threshold; keep the
-            smallest set of logits whose cumulative softmax probability
-            is >= ``top_p``. ``None`` (or ``>= 1.0``) disables top-p.
+            divide-by-zero). ``None`` skips temperature scaling and the
+            sampler returns the unmodified ``logits`` tensor.
 
     Returns:
         ``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified
-        ``logits`` tensor when all sampling parameters are ``None``.
+        ``logits`` tensor when ``temperature`` is ``None``.
     """
     # No sampling configured — return raw logits.
-    if temperature is None and top_k is None and top_p is None:
+    if temperature is None:
         return logits
 
-    if temperature is not None:
-        logits = logits / temperature.clamp(min=1e-6)
-
-    # Single sort handles both top-k and top-p filtering — both branches
-    # need descending logits anyway, so we share the sort to keep the
-    # graph small.
-    if top_k is not None or top_p is not None:
-        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
-        sorted_remove = torch.zeros_like(sorted_logits, dtype=torch.bool)
-
-        if top_k is not None:
-            # Position >= k → drop. Works for any tensor k via broadcast;
-            # k >= V naturally becomes a no-op (mask is all-False).
-            pos = torch.arange(sorted_logits.size(-1), device=sorted_logits.device)
-            sorted_remove = sorted_remove | (pos >= top_k.to(pos.dtype))
-
-        if top_p is not None:
-            cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
-            p_remove = cum_probs > top_p
-            # Shift right by one so the highest-prob token is always kept,
-            # even when its single-token prob already exceeds top_p.
-            p_remove = torch.cat(
-                [torch.zeros_like(p_remove[..., :1]), p_remove[..., :-1]],
-                dim=-1,
-            )
-            sorted_remove = sorted_remove | p_remove
-
-        sorted_logits = torch.where(
-            sorted_remove,
-            torch.full_like(sorted_logits, float("-inf")),
-            sorted_logits,
-        )
-        # Scatter the masked sorted logits back into original token order.
-        logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)
+    logits = logits / temperature.clamp(min=1e-6)
 
     # Gumbel-max sampling — equivalent to sampling from softmax(logits)
     # but fully on-device and CUDA-graph friendly. The 1e-20 epsilons are
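The diff view is cut off mid-comment here. For orientation, a sketch of the standard Gumbel-max step that such code leads into (an assumed reconstruction, not verbatim from the commit; the 1e-20 epsilons keep both ``log`` calls away from ``log(0)``):

    import torch

    def gumbel_max_pick(logits: torch.Tensor) -> torch.Tensor:
        # g = -log(-log(U + eps) + eps), U ~ Uniform(0, 1); eps avoids log(0).
        u = torch.rand_like(logits)
        gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
        return torch.argmax(logits + gumbel, dim=-1, keepdim=True)  # [B, 1]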
