diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
index 1060df4562..7a76ec4d8f 100644
--- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
+++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
@@ -142,7 +142,8 @@ def keep_conversation(entry):
     tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if tokenizer.chat_template is not None:
+        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
 
     output_dir = args.output_dir
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -206,10 +207,12 @@ async def submit_generates():
                 continue
 
             # Tokenize and check length
-            tokenized = tokenizer.apply_chat_template(
-                conversations, return_tensors="pt", add_generation_template=False
-            )
-            input_ids = tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
+            # return_dict=True ensures BatchEncoding is returned on all transformers
+            # versions: in <5.0 the default is False (returns raw tensor), in 5.0+
+            # the default changed to True (returns BatchEncoding).
+            input_ids = tokenizer.apply_chat_template(
+                conversations, return_tensors="pt", return_dict=True, add_generation_prompt=False
+            )["input_ids"]
             num_input_tokens = input_ids.shape[1]
             if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                 num_skipped_too_long += 1
diff --git a/tests/examples/speculative_decoding/conftest.py b/tests/examples/speculative_decoding/conftest.py
index 1750b9d454..4ea2bc1f29 100644
--- a/tests/examples/speculative_decoding/conftest.py
+++ b/tests/examples/speculative_decoding/conftest.py
@@ -13,11 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+
 import pytest
 import yaml
 from _test_utils.examples.run_command import run_example_command
 
 
+@pytest.fixture(scope="session")
+def tiny_conversations_path(tmp_path_factory):
+    """Tiny JSONL with short synthetic conversations for compute_hidden_states_hf tests.
+
+    Uses minimal single-turn conversations so that tokenized lengths stay well
+    within the tiny test model's max_position_embeddings (32) even after chat
+    template formatting.
+    """
+    tmp_dir = tmp_path_factory.mktemp("tiny_convs")
+    output_file = tmp_dir / "train.jsonl"
+    conversations = [
+        {
+            "conversation_id": f"test-{i}",
+            "conversations": [
+                {"role": "user", "content": "What is 2 plus 2?"},
+                {"role": "assistant", "content": "4"},
+            ],
+        }
+        for i in range(5)
+    ]
+    with open(output_file, "w") as f:
+        f.writelines(json.dumps(conv) + "\n" for conv in conversations)
+    return output_file
+
+
 @pytest.fixture(scope="session", autouse=True)
 def tiny_daring_anteater_path(tmp_path_factory):
     tmp_dir = tmp_path_factory.mktemp("daring_anteater")
diff --git a/tests/examples/speculative_decoding/test_eagle_offline_ptq.py b/tests/examples/speculative_decoding/test_eagle_offline_ptq.py
index baba7fb2a4..034a48189c 100644
--- a/tests/examples/speculative_decoding/test_eagle_offline_ptq.py
+++ b/tests/examples/speculative_decoding/test_eagle_offline_ptq.py
@@ -55,7 +55,7 @@
 }
 
 
-def test_collect_hidden_states(tiny_llama_path, tiny_daring_anteater_path, offline_ptq_dirs):
+def test_collect_hidden_states(tiny_llama_path, tiny_conversations_path, offline_ptq_dirs):
     """Stage 1: generate .pt hidden state files from the base model."""
     run_example_command(
         [
@@ -64,11 +64,13 @@ def test_collect_hidden_states(tiny_llama_path, tiny_daring_anteater_path, offli
             "--model",
             tiny_llama_path,
             "--input-data",
-            str(tiny_daring_anteater_path),
+            str(tiny_conversations_path),
             "--output-dir",
             str(offline_ptq_dirs["hidden_states"]),
             "--debug-max-num-conversations",
             "2",
+            "--max-seq-len",
+            "32",
         ],
         "speculative_decoding",
    )