@@ -45,7 +45,7 @@ def _compute_global_lse(
4545
4646
4747@triton .jit
48- def _compute_block_max_and_sumexp_kernel (
48+ def _compute_block_stats_kernel (
4949 # [num_logits, num_blocks]
5050 target_local_argmax_ptr ,
5151 target_local_argmax_stride ,
@@ -77,6 +77,7 @@ def _compute_block_max_and_sumexp_kernel(
7777 vocab_size ,
7878 num_speculative_steps ,
7979 BLOCK_SIZE : tl .constexpr ,
80+ HAS_DRAFT_LOGITS : tl .constexpr ,
8081):
8182 logit_idx = tl .program_id (0 )
8283 draft_step_idx = tl .load (expanded_local_pos_ptr + logit_idx )
@@ -110,24 +111,6 @@ def _compute_block_max_and_sumexp_kernel(
110111 value ,
111112 )
112113 else :
113- # Get local draft max and summed exponentials.
114- draft_logits = tl .load (
115- draft_logits_ptr
116- + req_state_idx * draft_logits_stride_0
117- + draft_step_idx * draft_logits_stride_1
118- + block_offsets ,
119- mask = mask ,
120- other = float ("-inf" ),
121- ).to (tl .float32 )
122- draft_max , draft_sumexp = _compute_block_max_and_sumexp (draft_logits )
123- tl .store (
124- draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx ,
125- draft_max ,
126- )
127- tl .store (
128- draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx ,
129- draft_sumexp ,
130- )
131114 # Get local target max and summed exponentials.
132115 target_logits = tl .load (
133116 target_logits_ptr + logit_idx * target_logits_stride + block_offsets ,
@@ -143,6 +126,25 @@ def _compute_block_max_and_sumexp_kernel(
143126 target_local_sumexp_ptr + logit_idx * target_local_sumexp_stride + block_idx ,
144127 target_sumexp ,
145128 )
129+ if HAS_DRAFT_LOGITS :
130+ # Get local draft max and summed exponentials.
131+ draft_logits = tl .load (
132+ draft_logits_ptr
133+ + req_state_idx * draft_logits_stride_0
134+ + draft_step_idx * draft_logits_stride_1
135+ + block_offsets ,
136+ mask = mask ,
137+ other = float ("-inf" ),
138+ ).to (tl .float32 )
139+ draft_max , draft_sumexp = _compute_block_max_and_sumexp (draft_logits )
140+ tl .store (
141+ draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx ,
142+ draft_max ,
143+ )
144+ tl .store (
145+ draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx ,
146+ draft_sumexp ,
147+ )
146148
147149
148150@triton .jit
@@ -192,6 +194,7 @@ def _probabilistic_rejection_kernel(
192194 pos_ptr ,
193195 vocab_num_blocks ,
194196 PADDED_VOCAB_NUM_BLOCKS : tl .constexpr ,
197+ HAS_DRAFT_LOGITS : tl .constexpr ,
195198):
196199 req_idx = tl .program_id (0 )
197200 req_state_idx = tl .load (idx_mapping_ptr + req_idx )
@@ -230,9 +233,6 @@ def _probabilistic_rejection_kernel(
230233 target_logit = tl .load (target_logits_ptr + logit_idx * target_logits_stride + draft_sampled ).to (
231234 tl .float32
232235 )
233- draft_logit = tl .load (
234- draft_logits_ptr + req_state_idx * draft_logits_stride_0 + i * draft_logits_stride_1 + draft_sampled
235- ).to (tl .float32 )
236236 target_lse = _compute_global_lse (
237237 target_local_max_ptr ,
238238 target_local_max_stride ,
@@ -242,19 +242,29 @@ def _probabilistic_rejection_kernel(
242242 vocab_num_blocks ,
243243 PADDED_VOCAB_NUM_BLOCKS ,
244244 )
245- draft_lse = _compute_global_lse (
246- draft_local_max_ptr ,
247- draft_local_max_stride ,
248- draft_local_sumexp_ptr ,
249- draft_local_sumexp_stride ,
250- logit_idx ,
251- vocab_num_blocks ,
252- PADDED_VOCAB_NUM_BLOCKS ,
253- )
254245 target_log_prob = target_logit - target_lse
255- draft_log_prob = draft_logit - draft_lse
256246 pos = tl .load (pos_ptr + logit_idx )
257247 u = tl_rand64 (seed , pos , includes_zero = False )
248+ if HAS_DRAFT_LOGITS :
249+ draft_logit = tl .load (
250+ draft_logits_ptr
251+ + req_state_idx * draft_logits_stride_0
252+ + i * draft_logits_stride_1
253+ + draft_sampled
254+ ).to (tl .float32 )
255+ draft_lse = _compute_global_lse (
256+ draft_local_max_ptr ,
257+ draft_local_max_stride ,
258+ draft_local_sumexp_ptr ,
259+ draft_local_sumexp_stride ,
260+ logit_idx ,
261+ vocab_num_blocks ,
262+ PADDED_VOCAB_NUM_BLOCKS ,
263+ )
264+ draft_log_prob = draft_logit - draft_lse
265+ else :
266+ # One-hot draft: q(draft_token) = 1, log_q = 0.
267+ draft_log_prob = 0
258268 # Probability ratio test: p(x) > u * q(x)
259269 # Equivalent log form: log_p(x) > log(u) + log_q(x)
260270 accepted &= target_log_prob > tl .log (u ) + draft_log_prob
@@ -290,6 +300,8 @@ def _resample_kernel(
290300 cu_num_logits_ptr ,
291301 # [num_logits]
292302 expanded_idx_mapping_ptr ,
303+ # [num_logits]
304+ draft_sampled_ptr ,
293305 # [max_num_reqs]
294306 temp_ptr ,
295307 # [max_num_reqs]
@@ -298,6 +310,7 @@ def _resample_kernel(
298310 pos_ptr ,
299311 vocab_size ,
300312 BLOCK_SIZE : tl .constexpr ,
313+ HAS_DRAFT_LOGITS : tl .constexpr ,
301314):
302315 req_idx = tl .program_id (0 )
303316 resample_idx = tl .load (rejected_step_ptr + req_idx )
@@ -316,22 +329,17 @@ def _resample_kernel(
316329 block_idx = tl .program_id (1 )
317330 block = block_idx * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
318331 mask = block < vocab_size
332+ target_logits = tl .load (
333+ target_logits_ptr + resample_token_idx * target_logits_stride + block ,
334+ mask = mask ,
335+ other = float ("-inf" ),
336+ ).to (tl .float32 )
319337
320- # Compute the residual logits to resample the rejected token
321- # from. In the case of no rejections (bonus token), we directly
322- # use the target logits.
338+ # Compute the residual logits to resample the rejected token from.
323339 if is_bonus :
324- residual_logits = tl .load (
325- target_logits_ptr + resample_token_idx * target_logits_stride + block ,
326- mask = mask ,
327- other = float ("-inf" ),
328- ).to (tl .float32 )
329- else :
330- target_logits = tl .load (
331- target_logits_ptr + resample_token_idx * target_logits_stride + block ,
332- mask = mask ,
333- other = float ("-inf" ),
334- ).to (tl .float32 )
340+ # Bonus token (no rejections). Directly use the target logits.
341+ residual_logits = target_logits
342+ elif HAS_DRAFT_LOGITS :
335343 draft_logits = tl .load (
336344 draft_logits_ptr + req_state_idx * draft_logits_stride_0 + resample_idx * draft_logits_stride_1 + block ,
337345 mask = mask ,
@@ -351,6 +359,15 @@ def _resample_kernel(
351359 target_log_probs + tl .log (1 - ratio ),
352360 float ("-inf" ),
353361 ).to (tl .float32 )
362+ else :
363+ # One-hot draft. The residual is just the target distribution with
364+ # the rejected draft token probability zeroed out.
365+ rejected_draft_token = tl .load (draft_sampled_ptr + resample_token_idx + 1 )
366+ residual_logits = tl .where (
367+ block != rejected_draft_token ,
368+ target_logits ,
369+ float ("-inf" ),
370+ ).to (tl .float32 )
354371
355372 # Resample the rejected/bonus token.
356373 value , idx = gumbel_block_argmax (
@@ -438,7 +455,7 @@ def probabilistic_rejection_sample(
438455 # [num_logits, V]
439456 target_logits : torch .Tensor ,
440457 # [max_num_reqs, num_speculative_steps, V]
441- draft_logits : torch .Tensor ,
458+ draft_logits : torch .Tensor | None ,
442459 # [num_logits]
443460 draft_sampled : torch .Tensor ,
444461 # [num_reqs + 1]
@@ -459,9 +476,17 @@ def probabilistic_rejection_sample(
459476) -> tuple [torch .Tensor , torch .Tensor ]:
460477 num_reqs = cu_num_logits .shape [0 ] - 1
461478 num_logits , vocab_size = target_logits .shape
479+ has_draft_logits = draft_logits is not None
480+
481+ if draft_logits is None :
482+ # When draft_logits is None, create a dummy tensor so that Triton
483+ # kernel signatures receive valid pointers/strides. The kernels
484+ # will never read from it when HAS_DRAFT_LOGITS=False.
485+ draft_logits = target_logits .new_empty (1 , 1 , 1 )
462486
463- # Gather draft logits, compute target argmax for greedy, and
464- # compute per-block LSE and max for non-greedy requests.
487+ # Compute the block-level logits stats, such as target argmax
488+ # (for greedy requests), and target max + softmax exponential
489+ # (for non-greedy requests).
465490 VOCAB_BLOCK_SIZE = 8192
466491 vocab_num_blocks = triton .cdiv (vocab_size , VOCAB_BLOCK_SIZE )
467492 padded_vocab_num_blocks = triton .next_power_of_2 (vocab_num_blocks )
@@ -470,7 +495,7 @@ def probabilistic_rejection_sample(
470495 target_local_sumexp = target_logits .new_empty (num_logits , vocab_num_blocks , dtype = torch .float32 )
471496 draft_local_max = target_logits .new_empty (num_logits , vocab_num_blocks , dtype = torch .float32 )
472497 draft_local_sumexp = target_logits .new_empty (num_logits , vocab_num_blocks , dtype = torch .float32 )
473- _compute_block_max_and_sumexp_kernel [(num_logits , vocab_num_blocks )](
498+ _compute_block_stats_kernel [(num_logits , vocab_num_blocks )](
474499 target_local_argmax ,
475500 target_local_argmax .stride (0 ),
476501 target_local_max ,
@@ -492,6 +517,7 @@ def probabilistic_rejection_sample(
492517 vocab_size ,
493518 num_speculative_steps ,
494519 BLOCK_SIZE = VOCAB_BLOCK_SIZE ,
520+ HAS_DRAFT_LOGITS = has_draft_logits ,
495521 )
496522
497523 # Sample up until the first rejected/bonus token, and store
@@ -529,6 +555,7 @@ def probabilistic_rejection_sample(
529555 pos ,
530556 vocab_num_blocks ,
531557 PADDED_VOCAB_NUM_BLOCKS = padded_vocab_num_blocks ,
558+ HAS_DRAFT_LOGITS = has_draft_logits ,
532559 num_warps = 1 ,
533560 )
534561
@@ -553,11 +580,13 @@ def probabilistic_rejection_sample(
553580 num_sampled ,
554581 cu_num_logits ,
555582 expanded_idx_mapping ,
583+ draft_sampled ,
556584 temperature ,
557585 seed ,
558586 pos ,
559587 vocab_size ,
560588 BLOCK_SIZE = RESAMPLE_BLOCK_SIZE ,
589+ HAS_DRAFT_LOGITS = has_draft_logits ,
561590 )
562591
563592 # Insert the resampled tokens into the output sampled.
0 commit comments