Skip to content

Commit b5d3e11

Browse files
authored
[sampler] fix mixed penalties in batch with async scheduling (#1594)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 5bce45e commit b5d3e11

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

aphrodite/v1/sample/ops/penalties.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ def apply_all_penalties(
1818
"""
1919
_, vocab_size = logits.shape
2020
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
21+
22+
# In the async scheduling case, rows that won't have penalties applied may contain
23+
# -1 placeholder token ids. We must replace these with valid token ids so that the
24+
# scatter done in apply_penalties is valid.
25+
# NOTE: The penalties implementation is currently quite inefficient and
26+
# will be reworked anyhow.
27+
output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
28+
2129
return apply_penalties(
2230
logits,
2331
prompt_token_ids,

0 commit comments

Comments
 (0)