
Commit ff0df2b

yeyu-nvidia and claude authored and committed
Fix test_collect_hidden_states: use synthetic short conversations (#1234)
## Summary

- `test_collect_hidden_states` was using real daring-anteater conversations (typically 1000+ tokens), but the tiny test model has `max_position_embeddings=32`. Both sampled conversations exceeded the default `--max-seq-len 3072` filter, producing zero `.pt` files and failing the assertion.
- Added a `tiny_conversations_path` fixture with synthetic short single-turn conversations that tokenize within `max_position_embeddings=32`.
- Changed `test_collect_hidden_states` to use this fixture with `--max-seq-len 32`.
- Added a `None` guard around `tokenizer.chat_template.replace(...)` to avoid an `AttributeError` when the tokenizer has no chat template.

## Test plan

- [ ] `pytest tests/examples/speculative_decoding/test_eagle_offline_ptq.py::test_collect_hidden_states` passes
- [ ] CI `speculative_decoding` job passes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

* **Bug Fixes**
  * Resolved compatibility issues when tokenizers do not have a chat template configuration by adding proper error handling.
  * Standardized tokenization input extraction logic across different transformer library versions for consistent behavior.
* **Tests**
  * Enhanced test infrastructure with new conversation data fixtures and improved sequence length validation for speculative decoding examples.

---

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 77902d5 commit ff0df2b

File tree

3 files changed: +33 −3 lines changed


examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -142,7 +142,8 @@ def keep_conversation(entry):
     tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if tokenizer.chat_template is not None:
+        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
 
     output_dir = args.output_dir
     output_dir.mkdir(parents=True, exist_ok=True)
```
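The guard above can be illustrated standalone. This is a minimal sketch using a stand-in object in place of a real Hugging Face tokenizer; `FakeTokenizer` and the `REMOVE_THINK_CHAT_TEMPLATE` value here are illustrative assumptions, not the script's actual definitions.

```python
# Illustrative marker string; the real constant lives in the script above.
REMOVE_THINK_CHAT_TEMPLATE = "<think></think>"


class FakeTokenizer:
    """Stand-in for AutoTokenizer: chat_template may be None for some models."""

    def __init__(self, chat_template):
        self.chat_template = chat_template


def strip_think_marker(tokenizer):
    # Guard: tokenizers without a chat template have chat_template == None,
    # and None.replace(...) would raise AttributeError.
    if tokenizer.chat_template is not None:
        tokenizer.chat_template = tokenizer.chat_template.replace(
            REMOVE_THINK_CHAT_TEMPLATE, ""
        )
    return tokenizer


with_template = strip_think_marker(FakeTokenizer("hi<think></think>there"))
without_template = strip_think_marker(FakeTokenizer(None))  # no AttributeError
print(with_template.chat_template)     # "hithere"
print(without_template.chat_template)  # None
```

Without the `is not None` check, the second call would crash before the script ever reached tokenization.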

tests/examples/speculative_decoding/conftest.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -13,11 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+
 import pytest
 import yaml
 from _test_utils.examples.run_command import run_example_command
 
 
+@pytest.fixture(scope="session")
+def tiny_conversations_path(tmp_path_factory):
+    """Tiny JSONL with short synthetic conversations for compute_hidden_states_hf tests.
+
+    Uses minimal single-turn conversations so that tokenized lengths stay well
+    within the tiny test model's max_position_embeddings (32) even after chat
+    template formatting.
+    """
+    tmp_dir = tmp_path_factory.mktemp("tiny_convs")
+    output_file = tmp_dir / "train.jsonl"
+    conversations = [
+        {
+            "conversation_id": f"test-{i}",
+            "conversations": [
+                {"role": "user", "content": "What is 2 plus 2?"},
+                {"role": "assistant", "content": "4"},
+            ],
+        }
+        for i in range(5)
+    ]
+    with open(output_file, "w") as f:
+        f.writelines(json.dumps(conv) + "\n" for conv in conversations)
+    return output_file
+
+
 @pytest.fixture(scope="session", autouse=True)
 def tiny_daring_anteater_path(tmp_path_factory):
     tmp_dir = tmp_path_factory.mktemp("daring_anteater")
```
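The fixture's JSONL layout can be sketched without pytest. This standalone version mirrors the fixture body above, swapping `tmp_path_factory` for `tempfile.mkdtemp` so it runs outside a test session.

```python
import json
import tempfile
from pathlib import Path

# Stand-in for pytest's tmp_path_factory in this standalone sketch.
tmp_dir = Path(tempfile.mkdtemp(prefix="tiny_convs"))
output_file = tmp_dir / "train.jsonl"

# Same synthetic single-turn conversations the fixture writes: short enough
# to tokenize within max_position_embeddings=32 even after chat templating.
conversations = [
    {
        "conversation_id": f"test-{i}",
        "conversations": [
            {"role": "user", "content": "What is 2 plus 2?"},
            {"role": "assistant", "content": "4"},
        ],
    }
    for i in range(5)
]
with open(output_file, "w") as f:
    f.writelines(json.dumps(conv) + "\n" for conv in conversations)

# Read it back: one JSON object per line, five short conversations total.
loaded = [json.loads(line) for line in output_file.read_text().splitlines()]
print(len(loaded))                              # 5
print(loaded[0]["conversations"][0]["content"]) # "What is 2 plus 2?"
```

One JSON object per line (JSONL) is the input format `compute_hidden_states_hf.py` expects via `--input-data`.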

tests/examples/speculative_decoding/test_eagle_offline_ptq.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -55,7 +55,7 @@ def offline_ptq_dirs(tmp_path_factory):
     }
 
 
-def test_collect_hidden_states(tiny_llama_path, tiny_daring_anteater_path, offline_ptq_dirs):
+def test_collect_hidden_states(tiny_llama_path, tiny_conversations_path, offline_ptq_dirs):
     """Stage 1: generate .pt hidden state files from the base model."""
     run_example_command(
         [
@@ -64,11 +64,13 @@ def test_collect_hidden_states(tiny_llama_path, tiny_daring_anteater_path, offli
             "--model",
             tiny_llama_path,
             "--input-data",
-            str(tiny_daring_anteater_path),
+            str(tiny_conversations_path),
             "--output-dir",
             str(offline_ptq_dirs["hidden_states"]),
             "--debug-max-num-conversations",
             "2",
+            "--max-seq-len",
+            "32",
         ],
         "speculative_decoding",
     )
```
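To see why the old test produced zero `.pt` files, the length filter described in the summary can be sketched as a hedged stand-in: `filter_by_length` and the token counts below are hypothetical illustrations, not the collection script's actual code.

```python
def filter_by_length(token_counts, max_seq_len):
    """Return indices of conversations whose token count passes the filter.

    Mirrors the behavior described in the commit summary: any conversation
    longer than --max-seq-len is skipped and produces no .pt output.
    """
    return [i for i, n in enumerate(token_counts) if n <= max_seq_len]


# Before the fix: both sampled real conversations exceeded the limit, so the
# filter kept nothing and the test's assertion on .pt files failed.
print(filter_by_length([4100, 3900], max_seq_len=3072))  # []

# After the fix: synthetic short conversations all pass --max-seq-len 32.
print(filter_by_length([14, 12, 15, 13, 14], max_seq_len=32))  # [0, 1, 2, 3, 4]
```

An empty result upstream means no hidden-state files downstream, which is exactly the failure mode the new fixture avoids.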
