Skip to content

Commit ff07459

Browse files
authored
[https://nvbugs/5823135][fix] Fix min_tokens not respected when prompt is long (#12166)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
1 parent b72ee4f commit ff07459

2 files changed

Lines changed: 45 additions & 2 deletions

File tree

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3871,13 +3871,18 @@ def _apply_min_length_penalty(
38713871
Returns:
38723872
The logits with min length penalty applied
38733873
"""
3874-
if any(r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0] for r in requests):
3874+
if any(
3875+
r.py_min_length and (r.max_beam_num_tokens - r.py_orig_prompt_len) < r.py_min_length[0]
3876+
for r in requests
3877+
):
38753878
current_offset = 0
38763879
for index, r in enumerate(requests):
38773880
if r.py_min_length:
38783881
for beam_idx in range(num_beams[index]):
38793882
for step in range(num_steps[index]):
3880-
if r.get_num_tokens(beam_idx) + step < r.py_min_length[0]:
3883+
if (
3884+
r.get_num_tokens(beam_idx) - r.py_orig_prompt_len
3885+
) + step < r.py_min_length[0]:
38813886
# NOTE(jthomson04): We can NOT just assign logits[...] = float("-inf").
38823887
# This introduces a pageable HtoD transfer, which wreaks havoc on TPOT (up to ~20%)
38833888
# Instead, we create a little tensor on device, then assign to that.

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1076,6 +1076,44 @@ def test_min_tokens(use_speculative: bool):
10761076
assert len(res.outputs[0].token_ids) == output_len
10771077

10781078

1079+
@pytest.mark.part0
def test_min_tokens_long_prompt():
    """Ensure min_tokens still holds when the prompt alone exceeds it.

    Guards against NVBug 5823135, where _apply_min_length_penalty measured
    the combined prompt + generated token count against min_tokens rather
    than the generated token count alone. Whenever prompt_len >= min_tokens
    the EOS suppression never kicked in, so generation could end early.
    """
    min_tok = 50
    token_budget = 100
    # "Hello " is about 1-2 tokens on most tokenizers, so 200 copies give
    # roughly 200-400 prompt tokens -- comfortably above min_tok.
    repetitions = 200
    prompt = "Hello " * repetitions

    llm_kwargs = dict(
        model=llama_model_path,
        max_batch_size=2,
        kv_cache_config=global_kvcache_config,
        max_num_tokens=2048,
    )
    llm = LLM(**llm_kwargs)

    result = llm.generate(
        prompt,
        sampling_params=SamplingParams(
            max_tokens=token_budget,
            min_tokens=min_tok,
            temperature=1,
        ),
    )

    # A single prompt must yield exactly one completion.
    completions = result.outputs
    assert len(completions) == 1

    # Only generated tokens count towards min_tokens; the prompt does not.
    generated_len = len(completions[0].token_ids)
    assert generated_len >= min_tok, (
        f"Generated only {generated_len} tokens with min_tokens={min_tok} "
        f"and a long prompt. Bug 5823135 regression.")
1115+
1116+
10791117
@skip_ray
10801118
@pytest.mark.parametrize(
10811119
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",

0 commit comments

Comments
 (0)