We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5bce45e commit b5d3e11Copy full SHA for b5d3e11
1 file changed
aphrodite/v1/sample/ops/penalties.py
@@ -18,6 +18,14 @@ def apply_all_penalties(
18
"""
19
_, vocab_size = logits.shape
20
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
21
+
22
+ # In the async scheduling case, rows that won't have penalties applied may contain
23
+ # -1 placeholder token ids. We must replace these with valid token ids so that the
24
+ # scatter done in apply_penalties is valid.
25
+ # NOTE: The penalties implementation is currently quite inefficient and
26
+ # will be reworked anyhow.
27
+ output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
28
29
return apply_penalties(
30
logits,
31
prompt_token_ids,
0 commit comments