|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | import torch |
4 | 4 |
|
5 | | -from aphrodite.utils.platform_utils import is_pin_memory_available |
6 | | -from aphrodite.utils.torch_utils import make_tensor_with_pad |
7 | 5 | from aphrodite.v1.sample.metadata import SamplingMetadata |
8 | 6 | from aphrodite.v1.sample.ops.bad_words import apply_bad_words |
9 | | -from aphrodite.v1.sample.ops.dry import apply_all_dry |
| 7 | +from aphrodite.v1.sample.ops.dry import ( |
| 8 | + DRY_STATE_KEY, |
| 9 | + DryRequestState, |
| 10 | + _compute_dry_penalties, |
| 11 | + _get_or_rebuild_dry_state, |
| 12 | +) |
10 | 13 | from aphrodite.v1.sample.ops.epsilon_cutoff import epsilon_cutoff |
11 | 14 | from aphrodite.v1.sample.ops.eta_cutoff import eta_cutoff |
12 | 15 | from aphrodite.v1.sample.ops.min_p import min_p |
@@ -93,40 +96,102 @@ def apply_dry( |
93 | 96 | sampling_metadata: SamplingMetadata, |
94 | 97 | ) -> torch.Tensor: |
95 | 98 | """Apply DRY sampling to the logits.""" |
96 | | - if sampling_metadata.dry_multiplier is not None and sampling_metadata.prompt_token_ids is not None: |
97 | | - # Convert output_token_ids to tensor |
98 | | - _, vocab_size = logits.shape |
99 | | - output_tokens_t = make_tensor_with_pad( |
100 | | - sampling_metadata.output_token_ids, |
101 | | - pad=vocab_size, |
102 | | - device="cpu", |
103 | | - dtype=torch.int64, |
104 | | - pin_memory=is_pin_memory_available(), |
105 | | - ).to(logits.device, non_blocking=True) |
106 | | - |
107 | | - # Ensure all required tensors are not None |
| 99 | + if ( |
| 100 | + sampling_metadata.dry_multiplier is not None |
| 101 | + and sampling_metadata.dry_base is not None |
| 102 | + and sampling_metadata.dry_allowed_length is not None |
| 103 | + and sampling_metadata.dry_sequence_breaker_ids is not None |
| 104 | + and sampling_metadata.dry_ranges is not None |
| 105 | + and sampling_metadata.dry_max_ngram is not None |
| 106 | + and sampling_metadata.dry_max_occurrences is not None |
| 107 | + and sampling_metadata.dry_early_exit_match_len is not None |
| 108 | + ): |
108 | 109 | if ( |
109 | | - sampling_metadata.dry_base is not None |
110 | | - and sampling_metadata.dry_allowed_length is not None |
111 | | - and sampling_metadata.dry_sequence_breaker_ids is not None |
112 | | - and sampling_metadata.dry_ranges is not None |
113 | | - and sampling_metadata.dry_max_ngram is not None |
114 | | - and sampling_metadata.dry_max_occurrences is not None |
115 | | - and sampling_metadata.dry_early_exit_match_len is not None |
| 110 | + sampling_metadata.token_history_ids_cpu is not None |
| 111 | + and sampling_metadata.token_history_lens_cpu is not None |
| 112 | + and sampling_metadata.dry_multiplier_cpu is not None |
| 113 | + and sampling_metadata.dry_allowed_length_cpu is not None |
| 114 | + and sampling_metadata.dry_sequence_breaker_ids_cpu is not None |
| 115 | + and sampling_metadata.dry_ranges_cpu is not None |
| 116 | + and sampling_metadata.dry_max_ngram_cpu is not None |
| 117 | + and sampling_metadata.dry_max_occurrences_cpu is not None |
| 118 | + and sampling_metadata.dry_early_exit_match_len_cpu is not None |
116 | 119 | ): |
117 | | - logits = apply_all_dry( |
118 | | - logits, |
119 | | - sampling_metadata.prompt_token_ids, |
120 | | - output_tokens_t, |
121 | | - sampling_metadata.dry_multiplier, |
122 | | - sampling_metadata.dry_base, |
123 | | - sampling_metadata.dry_allowed_length, |
124 | | - sampling_metadata.dry_sequence_breaker_ids, |
125 | | - sampling_metadata.dry_ranges, |
126 | | - sampling_metadata.dry_max_ngram, |
127 | | - sampling_metadata.dry_max_occurrences, |
128 | | - sampling_metadata.dry_early_exit_match_len, |
| 120 | + row_indexes_cpu, token_indexes_cpu, match_lens_cpu = torch.ops._C.dry_scan_penalties( |
| 121 | + sampling_metadata.token_history_ids_cpu, |
| 122 | + sampling_metadata.token_history_lens_cpu, |
| 123 | + sampling_metadata.dry_multiplier_cpu, |
| 124 | + sampling_metadata.dry_allowed_length_cpu, |
| 125 | + sampling_metadata.dry_sequence_breaker_ids_cpu, |
| 126 | + sampling_metadata.dry_ranges_cpu, |
| 127 | + sampling_metadata.dry_max_ngram_cpu, |
| 128 | + sampling_metadata.dry_max_occurrences_cpu, |
| 129 | + sampling_metadata.dry_early_exit_match_len_cpu, |
| 130 | + logits.size(-1), |
| 131 | + ) |
| 132 | + if row_indexes_cpu.numel(): |
| 133 | + row_indexes_gpu = row_indexes_cpu.to(device=logits.device, non_blocking=True) |
| 134 | + token_indexes_gpu = token_indexes_cpu.to(device=logits.device, non_blocking=True) |
| 135 | + match_lens_gpu = match_lens_cpu.to(device=logits.device, dtype=logits.dtype, non_blocking=True) |
| 136 | + allowed_lengths_t = sampling_metadata.dry_allowed_length[row_indexes_gpu].to(logits.dtype) |
| 137 | + scales = sampling_metadata.dry_base[row_indexes_gpu] ** (match_lens_gpu - allowed_lengths_t) |
| 138 | + logits[row_indexes_gpu, token_indexes_gpu] -= ( |
| 139 | + sampling_metadata.dry_multiplier[row_indexes_gpu] * scales |
| 140 | + ) |
| 141 | + return logits |
| 142 | + |
| 143 | + row_indexes: list[int] = [] |
| 144 | + token_indexes: list[int] = [] |
| 145 | + match_lens: list[int] = [] |
| 146 | + |
| 147 | + for irow_t in sampling_metadata.dry_multiplier.nonzero(as_tuple=True)[0]: |
| 148 | + irow = irow_t.item() |
| 149 | + persistent_entry = sampling_metadata.persistent_data.setdefault(irow, {}) |
| 150 | + dry_state = persistent_entry.get(DRY_STATE_KEY) |
| 151 | + breaker_ids = ( |
| 152 | + sampling_metadata.dry_sequence_breaker_ids[irow] |
| 153 | + .masked_select(sampling_metadata.dry_sequence_breaker_ids[irow] < logits.size(-1)) |
| 154 | + .tolist() |
| 155 | + ) |
| 156 | + expected_len = len(sampling_metadata.output_token_ids[irow]) |
| 157 | + if sampling_metadata.prompt_token_ids is not None: |
| 158 | + expected_len += (sampling_metadata.prompt_token_ids[irow] < logits.size(-1)).sum().item() |
| 159 | + if ( |
| 160 | + not isinstance(dry_state, DryRequestState) |
| 161 | + or dry_state.breaker_ids != frozenset(breaker_ids) |
| 162 | + or len(dry_state.history) != expected_len |
| 163 | + ): |
| 164 | + dry_state = _get_or_rebuild_dry_state( |
| 165 | + persistent_entry, |
| 166 | + None |
| 167 | + if sampling_metadata.prompt_token_ids is None |
| 168 | + else sampling_metadata.prompt_token_ids[irow] |
| 169 | + .masked_select(sampling_metadata.prompt_token_ids[irow] < logits.size(-1)) |
| 170 | + .tolist(), |
| 171 | + sampling_metadata.output_token_ids[irow], |
| 172 | + breaker_ids, |
| 173 | + ) |
| 174 | + penalties = _compute_dry_penalties( |
| 175 | + dry_state, |
| 176 | + allowed_length=sampling_metadata.dry_allowed_length[irow].item(), |
| 177 | + range_limit=sampling_metadata.dry_ranges[irow].item(), |
| 178 | + max_ngram=sampling_metadata.dry_max_ngram[irow].item(), |
| 179 | + max_occurrences=sampling_metadata.dry_max_occurrences[irow].item(), |
| 180 | + early_exit_match_len=sampling_metadata.dry_early_exit_match_len[irow].item(), |
129 | 181 | ) |
| 182 | + for token_id, match_len in penalties.items(): |
| 183 | + row_indexes.append(irow) |
| 184 | + token_indexes.append(token_id) |
| 185 | + match_lens.append(match_len) |
| 186 | + |
| 187 | + if row_indexes: |
| 188 | + row_indexes_t = torch.tensor(row_indexes, device=logits.device, dtype=torch.long) |
| 189 | + token_indexes_t = torch.tensor(token_indexes, device=logits.device, dtype=torch.long) |
| 190 | + match_lens_t = torch.tensor(match_lens, device=logits.device, dtype=logits.dtype) |
| 191 | + allowed_lengths_t = sampling_metadata.dry_allowed_length[row_indexes_t].to(logits.dtype) |
| 192 | + scales = sampling_metadata.dry_base[row_indexes_t] ** (match_lens_t - allowed_lengths_t) |
| 193 | + logits[row_indexes_t, token_indexes_t] -= sampling_metadata.dry_multiplier[row_indexes_t] * scales |
| 194 | + return logits |
130 | 195 | return logits |
131 | 196 |
|
132 | 197 | def apply_no_repeat_ngram( |
|
0 commit comments