Skip to content

Commit f48068d

Browse files
authored
fix: rejection sampling acceptance rate in MRv2 (#1658)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 939683e commit f48068d

4 files changed

Lines changed: 124 additions & 139 deletions

File tree

aphrodite/config/speculative.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
EagleModelTypes,
6060
NgramGPUTypes,
6161
]
62-
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]
62+
RejectionSampleMethod = Literal["standard", "synthetic"]
63+
DraftSampleMethod = Literal["greedy", "gumbel"]
6364

6465

6566
@config
@@ -179,11 +180,11 @@ class SpeculativeConfig:
179180
"""Load config for the draft model. If not specified, will use the load
180181
config from the target model."""
181182

182-
rejection_sample_method: RejectionSampleMethod = "strict"
183-
"""Whether to use strict (target and draft sampled tokens match exactly)
184-
or probabilistic rejection sampling. Both respect the target model
185-
distribution, but the latter yields a higher acceptance rate at the cost
186-
of more memory to cache draft logits."""
183+
rejection_sample_method: RejectionSampleMethod = "standard"
184+
"""The rejection sampling method to use. 'standard' uses probabilistic
185+
rejection sampling (with or without cached draft logits, controlled by
186+
draft_sample_method). 'synthetic' accepts draft tokens with a decaying
187+
probability calibrated to synthetic_acceptance_rates."""
187188

188189
synthetic_acceptance_rates: list[float] | None = None
189190
"""Per-position *unconditional* acceptance rates for synthetic rejection
@@ -233,6 +234,14 @@ def _resolve_synthetic_acceptance_rates(
233234
raise ValueError(f"synthetic_acceptance_length must be in [1, {n + 1}], got {length}.")
234235
return SpeculativeConfig._acceptance_length_to_rates(length, n)
235236

237+
draft_sample_method: DraftSampleMethod = "greedy"
238+
"""How the draft model samples tokens. 'greedy' always picks the argmax
239+
token, and the draft probabilities are treated as one-hot during rejection
240+
sampling. 'gumbel' adds Gumbel noise for stochastic sampling, and the full
241+
draft logits are used for the probability ratio test during rejection
242+
sampling. This comes at the cost of additional GPU memory usage. This
243+
parameter currently only applies to Model Runner V2."""
244+
236245
def compute_hash(self) -> str:
237246
"""
238247
WARNING: Whenever a new field is added to this config,

aphrodite/v1/worker/gpu/spec_decode/eagle/speculator.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(self, aphrodite_config: AphroditeConfig, device: torch.device):
8787
self.inputs_embeds = torch.zeros(self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device)
8888

8989
self.draft_logits: torch.Tensor | None = None
90-
if self.speculative_config.rejection_sample_method == "probabilistic":
90+
if self.speculative_config.draft_sample_method == "gumbel":
9191
self.draft_logits = torch.zeros(
9292
self.max_num_reqs,
9393
self.num_speculative_steps,
@@ -204,6 +204,28 @@ def run_model(
204204
last_hidden_states, hidden_states = ret_hidden_states
205205
return last_hidden_states, hidden_states
206206

207+
def _sample_draft(
208+
self,
209+
logits: torch.Tensor,
210+
idx_mapping: torch.Tensor,
211+
pos: torch.Tensor,
212+
step: int,
213+
) -> torch.Tensor:
214+
if self.draft_logits is not None:
215+
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
216+
# used for draft and target sampling.
217+
return gumbel_sample(
218+
logits,
219+
idx_mapping,
220+
self.temperature,
221+
self.seeds,
222+
pos + 1,
223+
apply_temperature=True,
224+
processed_logits_out=self.draft_logits[:, step],
225+
)
226+
else:
227+
return logits.argmax(dim=-1)
228+
207229
def prefill(
208230
self,
209231
num_reqs: int,
@@ -229,16 +251,11 @@ def prefill(
229251
sample_hidden_states = last_hidden_states[last_token_indices]
230252
logits = self.model.compute_logits(sample_hidden_states)
231253

232-
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
233-
# used for draft and target sampling.
234-
self.draft_tokens[:num_reqs, 0] = gumbel_sample(
254+
self.draft_tokens[:num_reqs, 0] = self._sample_draft(
235255
logits,
236256
idx_mapping,
237-
self.temperature,
238-
self.seeds,
239-
pos + 1,
240-
apply_temperature=True,
241-
processed_logits_out=self.draft_logits[:, 0] if self.draft_logits is not None else None,
257+
pos,
258+
step=0,
242259
)
243260
self.hidden_states[:num_reqs] = hidden_states[last_token_indices]
244261
self.input_buffers.positions[:num_reqs] = pos
@@ -268,16 +285,11 @@ def generate_draft(
268285
hidden_states = hidden_states[:num_reqs]
269286
logits = self.model.compute_logits(last_hidden_states)
270287

271-
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
272-
# used for draft and target sampling.
273-
draft_tokens = gumbel_sample(
288+
draft_tokens = self._sample_draft(
274289
logits,
275290
idx_mapping,
276-
self.temperature,
277-
self.seeds,
278-
pos + 1,
279-
apply_temperature=True,
280-
processed_logits_out=self.draft_logits[:, step] if self.draft_logits is not None else None,
291+
pos,
292+
step=step,
281293
)
282294
self.draft_tokens[:num_reqs, step] = draft_tokens
283295

aphrodite/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py

Lines changed: 79 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _compute_global_lse(
4545

4646

4747
@triton.jit
48-
def _compute_block_max_and_sumexp_kernel(
48+
def _compute_block_stats_kernel(
4949
# [num_logits, num_blocks]
5050
target_local_argmax_ptr,
5151
target_local_argmax_stride,
@@ -77,6 +77,7 @@ def _compute_block_max_and_sumexp_kernel(
7777
vocab_size,
7878
num_speculative_steps,
7979
BLOCK_SIZE: tl.constexpr,
80+
HAS_DRAFT_LOGITS: tl.constexpr,
8081
):
8182
logit_idx = tl.program_id(0)
8283
draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
@@ -110,24 +111,6 @@ def _compute_block_max_and_sumexp_kernel(
110111
value,
111112
)
112113
else:
113-
# Get local draft max and summed exponentials.
114-
draft_logits = tl.load(
115-
draft_logits_ptr
116-
+ req_state_idx * draft_logits_stride_0
117-
+ draft_step_idx * draft_logits_stride_1
118-
+ block_offsets,
119-
mask=mask,
120-
other=float("-inf"),
121-
).to(tl.float32)
122-
draft_max, draft_sumexp = _compute_block_max_and_sumexp(draft_logits)
123-
tl.store(
124-
draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx,
125-
draft_max,
126-
)
127-
tl.store(
128-
draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx,
129-
draft_sumexp,
130-
)
131114
# Get local target max and summed exponentials.
132115
target_logits = tl.load(
133116
target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
@@ -143,6 +126,25 @@ def _compute_block_max_and_sumexp_kernel(
143126
target_local_sumexp_ptr + logit_idx * target_local_sumexp_stride + block_idx,
144127
target_sumexp,
145128
)
129+
if HAS_DRAFT_LOGITS:
130+
# Get local draft max and summed exponentials.
131+
draft_logits = tl.load(
132+
draft_logits_ptr
133+
+ req_state_idx * draft_logits_stride_0
134+
+ draft_step_idx * draft_logits_stride_1
135+
+ block_offsets,
136+
mask=mask,
137+
other=float("-inf"),
138+
).to(tl.float32)
139+
draft_max, draft_sumexp = _compute_block_max_and_sumexp(draft_logits)
140+
tl.store(
141+
draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx,
142+
draft_max,
143+
)
144+
tl.store(
145+
draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx,
146+
draft_sumexp,
147+
)
146148

147149

148150
@triton.jit
@@ -192,6 +194,7 @@ def _probabilistic_rejection_kernel(
192194
pos_ptr,
193195
vocab_num_blocks,
194196
PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
197+
HAS_DRAFT_LOGITS: tl.constexpr,
195198
):
196199
req_idx = tl.program_id(0)
197200
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
@@ -230,9 +233,6 @@ def _probabilistic_rejection_kernel(
230233
target_logit = tl.load(target_logits_ptr + logit_idx * target_logits_stride + draft_sampled).to(
231234
tl.float32
232235
)
233-
draft_logit = tl.load(
234-
draft_logits_ptr + req_state_idx * draft_logits_stride_0 + i * draft_logits_stride_1 + draft_sampled
235-
).to(tl.float32)
236236
target_lse = _compute_global_lse(
237237
target_local_max_ptr,
238238
target_local_max_stride,
@@ -242,19 +242,29 @@ def _probabilistic_rejection_kernel(
242242
vocab_num_blocks,
243243
PADDED_VOCAB_NUM_BLOCKS,
244244
)
245-
draft_lse = _compute_global_lse(
246-
draft_local_max_ptr,
247-
draft_local_max_stride,
248-
draft_local_sumexp_ptr,
249-
draft_local_sumexp_stride,
250-
logit_idx,
251-
vocab_num_blocks,
252-
PADDED_VOCAB_NUM_BLOCKS,
253-
)
254245
target_log_prob = target_logit - target_lse
255-
draft_log_prob = draft_logit - draft_lse
256246
pos = tl.load(pos_ptr + logit_idx)
257247
u = tl_rand64(seed, pos, includes_zero=False)
248+
if HAS_DRAFT_LOGITS:
249+
draft_logit = tl.load(
250+
draft_logits_ptr
251+
+ req_state_idx * draft_logits_stride_0
252+
+ i * draft_logits_stride_1
253+
+ draft_sampled
254+
).to(tl.float32)
255+
draft_lse = _compute_global_lse(
256+
draft_local_max_ptr,
257+
draft_local_max_stride,
258+
draft_local_sumexp_ptr,
259+
draft_local_sumexp_stride,
260+
logit_idx,
261+
vocab_num_blocks,
262+
PADDED_VOCAB_NUM_BLOCKS,
263+
)
264+
draft_log_prob = draft_logit - draft_lse
265+
else:
266+
# One-hot draft: q(draft_token) = 1, log_q = 0.
267+
draft_log_prob = 0
258268
# Probability ratio test: p(x) > u * q(x)
259269
# Equivalent log form: log_p(x) > log(u) + log_q(x)
260270
accepted &= target_log_prob > tl.log(u) + draft_log_prob
@@ -290,6 +300,8 @@ def _resample_kernel(
290300
cu_num_logits_ptr,
291301
# [num_logits]
292302
expanded_idx_mapping_ptr,
303+
# [num_logits]
304+
draft_sampled_ptr,
293305
# [max_num_reqs]
294306
temp_ptr,
295307
# [max_num_reqs]
@@ -298,6 +310,7 @@ def _resample_kernel(
298310
pos_ptr,
299311
vocab_size,
300312
BLOCK_SIZE: tl.constexpr,
313+
HAS_DRAFT_LOGITS: tl.constexpr,
301314
):
302315
req_idx = tl.program_id(0)
303316
resample_idx = tl.load(rejected_step_ptr + req_idx)
@@ -316,22 +329,17 @@ def _resample_kernel(
316329
block_idx = tl.program_id(1)
317330
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
318331
mask = block < vocab_size
332+
target_logits = tl.load(
333+
target_logits_ptr + resample_token_idx * target_logits_stride + block,
334+
mask=mask,
335+
other=float("-inf"),
336+
).to(tl.float32)
319337

320-
# Compute the residual logits to resample the rejected token
321-
# from. In the case of no rejections (bonus token), we directly
322-
# use the target logits.
338+
# Compute the residual logits to resample the rejected token from.
323339
if is_bonus:
324-
residual_logits = tl.load(
325-
target_logits_ptr + resample_token_idx * target_logits_stride + block,
326-
mask=mask,
327-
other=float("-inf"),
328-
).to(tl.float32)
329-
else:
330-
target_logits = tl.load(
331-
target_logits_ptr + resample_token_idx * target_logits_stride + block,
332-
mask=mask,
333-
other=float("-inf"),
334-
).to(tl.float32)
340+
# Bonus token (no rejections). Directly use the target logits.
341+
residual_logits = target_logits
342+
elif HAS_DRAFT_LOGITS:
335343
draft_logits = tl.load(
336344
draft_logits_ptr + req_state_idx * draft_logits_stride_0 + resample_idx * draft_logits_stride_1 + block,
337345
mask=mask,
@@ -351,6 +359,15 @@ def _resample_kernel(
351359
target_log_probs + tl.log(1 - ratio),
352360
float("-inf"),
353361
).to(tl.float32)
362+
else:
363+
# One-hot draft. The residual is just the target distribution with
364+
# the rejected draft token probability zeroed out.
365+
rejected_draft_token = tl.load(draft_sampled_ptr + resample_token_idx + 1)
366+
residual_logits = tl.where(
367+
block != rejected_draft_token,
368+
target_logits,
369+
float("-inf"),
370+
).to(tl.float32)
354371

355372
# Resample the rejected/bonus token.
356373
value, idx = gumbel_block_argmax(
@@ -438,7 +455,7 @@ def probabilistic_rejection_sample(
438455
# [num_logits, V]
439456
target_logits: torch.Tensor,
440457
# [max_num_reqs, num_speculative_steps, V]
441-
draft_logits: torch.Tensor,
458+
draft_logits: torch.Tensor | None,
442459
# [num_logits]
443460
draft_sampled: torch.Tensor,
444461
# [num_reqs + 1]
@@ -459,9 +476,17 @@ def probabilistic_rejection_sample(
459476
) -> tuple[torch.Tensor, torch.Tensor]:
460477
num_reqs = cu_num_logits.shape[0] - 1
461478
num_logits, vocab_size = target_logits.shape
479+
has_draft_logits = draft_logits is not None
480+
481+
if draft_logits is None:
482+
# When draft_logits is None, create a dummy tensor so that Triton
483+
# kernel signatures receive valid pointers/strides. The kernels
484+
# will never read from it when HAS_DRAFT_LOGITS=False.
485+
draft_logits = target_logits.new_empty(1, 1, 1)
462486

463-
# Gather draft logits, compute target argmax for greedy, and
464-
# compute per-block LSE and max for non-greedy requests.
487+
# Compute the block-level logits stats, such as target argmax
488+
# (for greedy requests), and target max + summed exponentials
489+
# (for non-greedy requests).
465490
VOCAB_BLOCK_SIZE = 8192
466491
vocab_num_blocks = triton.cdiv(vocab_size, VOCAB_BLOCK_SIZE)
467492
padded_vocab_num_blocks = triton.next_power_of_2(vocab_num_blocks)
@@ -470,7 +495,7 @@ def probabilistic_rejection_sample(
470495
target_local_sumexp = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32)
471496
draft_local_max = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32)
472497
draft_local_sumexp = target_logits.new_empty(num_logits, vocab_num_blocks, dtype=torch.float32)
473-
_compute_block_max_and_sumexp_kernel[(num_logits, vocab_num_blocks)](
498+
_compute_block_stats_kernel[(num_logits, vocab_num_blocks)](
474499
target_local_argmax,
475500
target_local_argmax.stride(0),
476501
target_local_max,
@@ -492,6 +517,7 @@ def probabilistic_rejection_sample(
492517
vocab_size,
493518
num_speculative_steps,
494519
BLOCK_SIZE=VOCAB_BLOCK_SIZE,
520+
HAS_DRAFT_LOGITS=has_draft_logits,
495521
)
496522

497523
# Sample up until the first rejected/bonus token, and store
@@ -529,6 +555,7 @@ def probabilistic_rejection_sample(
529555
pos,
530556
vocab_num_blocks,
531557
PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks,
558+
HAS_DRAFT_LOGITS=has_draft_logits,
532559
num_warps=1,
533560
)
534561

@@ -553,11 +580,13 @@ def probabilistic_rejection_sample(
553580
num_sampled,
554581
cu_num_logits,
555582
expanded_idx_mapping,
583+
draft_sampled,
556584
temperature,
557585
seed,
558586
pos,
559587
vocab_size,
560588
BLOCK_SIZE=RESAMPLE_BLOCK_SIZE,
589+
HAS_DRAFT_LOGITS=has_draft_logits,
561590
)
562591

563592
# Insert the resampled tokens into the output sampled.

0 commit comments

Comments
 (0)