Skip to content

Commit ae4d41d

Browse files
authored
chore: massively improve DRY performance (#1634)
* chore: massively improve DRY performance Signed-off-by: AlpinDale <alpindale@gmail.com> * chore: refactor dry_scan_penalties_cpu for template support and type safety Signed-off-by: AlpinDale <alpindale@gmail.com> * fix: off-by-one issue in native DRY kernel Signed-off-by: AlpinDale <alpindale@gmail.com> --------- Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent fbd30d1 commit ae4d41d

11 files changed

Lines changed: 872 additions & 95 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ set(APHRODITE_EXT_SRC
290290
"csrc/quantization/activation_kernels.cu"
291291
"csrc/cuda_utils_kernels.cu"
292292
"csrc/all_reduce/custom_all_reduce.cu"
293+
"csrc/cpu/dry.cpp"
293294
"csrc/torch_bindings.cpp")
294295

295296
if(APHRODITE_GPU_LANG STREQUAL "CUDA")

aphrodite/v1/sample/metadata.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,17 @@ class SamplingMetadata:
105105

106106
# Speculative token ids
107107
spec_token_ids: list[list[int]] | None = None
108+
109+
# Cached padded token-history tensor for GPU-side sampler ops.
110+
output_token_ids_tensor: torch.Tensor | None = None
111+
token_history_ids: torch.Tensor | None = None
112+
token_history_lens: torch.Tensor | None = None
113+
token_history_ids_cpu: torch.Tensor | None = None
114+
token_history_lens_cpu: torch.Tensor | None = None
115+
dry_multiplier_cpu: torch.Tensor | None = None
116+
dry_allowed_length_cpu: torch.Tensor | None = None
117+
dry_sequence_breaker_ids_cpu: torch.Tensor | None = None
118+
dry_ranges_cpu: torch.Tensor | None = None
119+
dry_max_ngram_cpu: torch.Tensor | None = None
120+
dry_max_occurrences_cpu: torch.Tensor | None = None
121+
dry_early_exit_match_len_cpu: torch.Tensor | None = None

aphrodite/v1/sample/ops/__init__.py

Lines changed: 99 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import torch
44

5-
from aphrodite.utils.platform_utils import is_pin_memory_available
6-
from aphrodite.utils.torch_utils import make_tensor_with_pad
75
from aphrodite.v1.sample.metadata import SamplingMetadata
86
from aphrodite.v1.sample.ops.bad_words import apply_bad_words
9-
from aphrodite.v1.sample.ops.dry import apply_all_dry
7+
from aphrodite.v1.sample.ops.dry import (
8+
DRY_STATE_KEY,
9+
DryRequestState,
10+
_compute_dry_penalties,
11+
_get_or_rebuild_dry_state,
12+
)
1013
from aphrodite.v1.sample.ops.epsilon_cutoff import epsilon_cutoff
1114
from aphrodite.v1.sample.ops.eta_cutoff import eta_cutoff
1215
from aphrodite.v1.sample.ops.min_p import min_p
@@ -93,40 +96,102 @@ def apply_dry(
9396
sampling_metadata: SamplingMetadata,
9497
) -> torch.Tensor:
9598
"""Apply DRY sampling to the logits."""
96-
if sampling_metadata.dry_multiplier is not None and sampling_metadata.prompt_token_ids is not None:
97-
# Convert output_token_ids to tensor
98-
_, vocab_size = logits.shape
99-
output_tokens_t = make_tensor_with_pad(
100-
sampling_metadata.output_token_ids,
101-
pad=vocab_size,
102-
device="cpu",
103-
dtype=torch.int64,
104-
pin_memory=is_pin_memory_available(),
105-
).to(logits.device, non_blocking=True)
106-
107-
# Ensure all required tensors are not None
99+
if (
100+
sampling_metadata.dry_multiplier is not None
101+
and sampling_metadata.dry_base is not None
102+
and sampling_metadata.dry_allowed_length is not None
103+
and sampling_metadata.dry_sequence_breaker_ids is not None
104+
and sampling_metadata.dry_ranges is not None
105+
and sampling_metadata.dry_max_ngram is not None
106+
and sampling_metadata.dry_max_occurrences is not None
107+
and sampling_metadata.dry_early_exit_match_len is not None
108+
):
108109
if (
109-
sampling_metadata.dry_base is not None
110-
and sampling_metadata.dry_allowed_length is not None
111-
and sampling_metadata.dry_sequence_breaker_ids is not None
112-
and sampling_metadata.dry_ranges is not None
113-
and sampling_metadata.dry_max_ngram is not None
114-
and sampling_metadata.dry_max_occurrences is not None
115-
and sampling_metadata.dry_early_exit_match_len is not None
110+
sampling_metadata.token_history_ids_cpu is not None
111+
and sampling_metadata.token_history_lens_cpu is not None
112+
and sampling_metadata.dry_multiplier_cpu is not None
113+
and sampling_metadata.dry_allowed_length_cpu is not None
114+
and sampling_metadata.dry_sequence_breaker_ids_cpu is not None
115+
and sampling_metadata.dry_ranges_cpu is not None
116+
and sampling_metadata.dry_max_ngram_cpu is not None
117+
and sampling_metadata.dry_max_occurrences_cpu is not None
118+
and sampling_metadata.dry_early_exit_match_len_cpu is not None
116119
):
117-
logits = apply_all_dry(
118-
logits,
119-
sampling_metadata.prompt_token_ids,
120-
output_tokens_t,
121-
sampling_metadata.dry_multiplier,
122-
sampling_metadata.dry_base,
123-
sampling_metadata.dry_allowed_length,
124-
sampling_metadata.dry_sequence_breaker_ids,
125-
sampling_metadata.dry_ranges,
126-
sampling_metadata.dry_max_ngram,
127-
sampling_metadata.dry_max_occurrences,
128-
sampling_metadata.dry_early_exit_match_len,
120+
row_indexes_cpu, token_indexes_cpu, match_lens_cpu = torch.ops._C.dry_scan_penalties(
121+
sampling_metadata.token_history_ids_cpu,
122+
sampling_metadata.token_history_lens_cpu,
123+
sampling_metadata.dry_multiplier_cpu,
124+
sampling_metadata.dry_allowed_length_cpu,
125+
sampling_metadata.dry_sequence_breaker_ids_cpu,
126+
sampling_metadata.dry_ranges_cpu,
127+
sampling_metadata.dry_max_ngram_cpu,
128+
sampling_metadata.dry_max_occurrences_cpu,
129+
sampling_metadata.dry_early_exit_match_len_cpu,
130+
logits.size(-1),
131+
)
132+
if row_indexes_cpu.numel():
133+
row_indexes_gpu = row_indexes_cpu.to(device=logits.device, non_blocking=True)
134+
token_indexes_gpu = token_indexes_cpu.to(device=logits.device, non_blocking=True)
135+
match_lens_gpu = match_lens_cpu.to(device=logits.device, dtype=logits.dtype, non_blocking=True)
136+
allowed_lengths_t = sampling_metadata.dry_allowed_length[row_indexes_gpu].to(logits.dtype)
137+
scales = sampling_metadata.dry_base[row_indexes_gpu] ** (match_lens_gpu - allowed_lengths_t)
138+
logits[row_indexes_gpu, token_indexes_gpu] -= (
139+
sampling_metadata.dry_multiplier[row_indexes_gpu] * scales
140+
)
141+
return logits
142+
143+
row_indexes: list[int] = []
144+
token_indexes: list[int] = []
145+
match_lens: list[int] = []
146+
147+
for irow_t in sampling_metadata.dry_multiplier.nonzero(as_tuple=True)[0]:
148+
irow = irow_t.item()
149+
persistent_entry = sampling_metadata.persistent_data.setdefault(irow, {})
150+
dry_state = persistent_entry.get(DRY_STATE_KEY)
151+
breaker_ids = (
152+
sampling_metadata.dry_sequence_breaker_ids[irow]
153+
.masked_select(sampling_metadata.dry_sequence_breaker_ids[irow] < logits.size(-1))
154+
.tolist()
155+
)
156+
expected_len = len(sampling_metadata.output_token_ids[irow])
157+
if sampling_metadata.prompt_token_ids is not None:
158+
expected_len += (sampling_metadata.prompt_token_ids[irow] < logits.size(-1)).sum().item()
159+
if (
160+
not isinstance(dry_state, DryRequestState)
161+
or dry_state.breaker_ids != frozenset(breaker_ids)
162+
or len(dry_state.history) != expected_len
163+
):
164+
dry_state = _get_or_rebuild_dry_state(
165+
persistent_entry,
166+
None
167+
if sampling_metadata.prompt_token_ids is None
168+
else sampling_metadata.prompt_token_ids[irow]
169+
.masked_select(sampling_metadata.prompt_token_ids[irow] < logits.size(-1))
170+
.tolist(),
171+
sampling_metadata.output_token_ids[irow],
172+
breaker_ids,
173+
)
174+
penalties = _compute_dry_penalties(
175+
dry_state,
176+
allowed_length=sampling_metadata.dry_allowed_length[irow].item(),
177+
range_limit=sampling_metadata.dry_ranges[irow].item(),
178+
max_ngram=sampling_metadata.dry_max_ngram[irow].item(),
179+
max_occurrences=sampling_metadata.dry_max_occurrences[irow].item(),
180+
early_exit_match_len=sampling_metadata.dry_early_exit_match_len[irow].item(),
129181
)
182+
for token_id, match_len in penalties.items():
183+
row_indexes.append(irow)
184+
token_indexes.append(token_id)
185+
match_lens.append(match_len)
186+
187+
if row_indexes:
188+
row_indexes_t = torch.tensor(row_indexes, device=logits.device, dtype=torch.long)
189+
token_indexes_t = torch.tensor(token_indexes, device=logits.device, dtype=torch.long)
190+
match_lens_t = torch.tensor(match_lens, device=logits.device, dtype=logits.dtype)
191+
allowed_lengths_t = sampling_metadata.dry_allowed_length[row_indexes_t].to(logits.dtype)
192+
scales = sampling_metadata.dry_base[row_indexes_t] ** (match_lens_t - allowed_lengths_t)
193+
logits[row_indexes_t, token_indexes_t] -= sampling_metadata.dry_multiplier[row_indexes_t] * scales
194+
return logits
130195
return logits
131196

132197
def apply_no_repeat_ngram(

0 commit comments

Comments
 (0)