Commit cf70e89

yeyu-nvidia and claude authored and committed
Fix compute_hidden_states_hf.py: handle BatchEncoding from apply_chat_template (#1225)
## Summary

- `apply_chat_template(..., return_tensors="pt")` returns a `BatchEncoding` in transformers 4.46+, which no longer subclasses `dict`.
- The old guard `isinstance(tokenized, dict)` evaluates to `False` for a `BatchEncoding`, so `input_ids` was set to the whole `BatchEncoding` object.
- Calling `.shape[1]` on a `BatchEncoding` triggers `__getattr__("shape")`, which raises `AttributeError`.
- Fix: check `isinstance(tokenized, torch.Tensor)` instead, which correctly handles both old transformers (plain tensor) and new transformers (`BatchEncoding`).

This was causing `test_collect_hidden_states` to fail in the speculative decoding CI for all open PRs (#1207, #1210, #1221).

## Test plan

- [ ] `torch-pr (speculative_decoding, 26.01)` CI passes
- [ ] Verify the fix handles both the `torch.Tensor` return (old transformers) and the `BatchEncoding` return (new transformers 4.46+)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
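The failure mode described above can be reproduced without transformers at all. The sketch below uses a hypothetical `FakeBatchEncoding` stand-in (built on `collections.UserDict`, which the real `BatchEncoding` also uses in newer transformers releases) to show why the `isinstance(tokenized, dict)` guard falls through and why attribute access then raises; the class and its `__getattr__` are illustrative, not the real library code.

```python
from collections import UserDict


class FakeBatchEncoding(UserDict):
    """Illustrative stand-in: dict-like (via UserDict) but NOT a dict subclass."""

    def __getattr__(self, item):
        # Mirrors the documented BatchEncoding behavior: unknown attributes
        # are looked up among the tokenizer output keys, else AttributeError.
        try:
            return self.__dict__["data"][item]
        except KeyError:
            raise AttributeError(item)


tokenized = FakeBatchEncoding({"input_ids": [[1, 2, 3]]})

# The old guard: since the wrapper is not a dict, the else-branch binds
# input_ids to the whole wrapper object instead of the tensor inside it.
old_input_ids = tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
assert not isinstance(tokenized, dict)  # UserDict is not a dict subclass
assert old_input_ids is tokenized       # the bug: wrapper, not input_ids

# Downstream code then does input_ids.shape[1], which blows up:
try:
    _ = old_input_ids.shape
except AttributeError as exc:
    print(f"AttributeError: {exc}")
```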
1 parent f1a2260 commit cf70e89

1 file changed: 6 additions & 4 deletions

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
```diff
@@ -206,10 +206,12 @@ async def submit_generates():
                 continue

             # Tokenize and check length
-            tokenized = tokenizer.apply_chat_template(
-                conversations, return_tensors="pt", add_generation_template=False
-            )
-            input_ids = tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
+            # return_dict=True ensures BatchEncoding is returned on all transformers
+            # versions: in <5.0 the default is False (returns raw tensor), in 5.0+
+            # the default changed to True (returns BatchEncoding).
+            input_ids = tokenizer.apply_chat_template(
+                conversations, return_tensors="pt", return_dict=True, add_generation_template=False
+            )["input_ids"]
             num_input_tokens = input_ids.shape[1]
             if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                 num_skipped_too_long += 1
```
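An alternative defensive pattern, sketched below under the assumption that `BatchEncoding` is dict-like (it builds on `collections.UserDict`, which registers as a `Mapping`), is to key on the mapping protocol rather than on a concrete class. The `extract_input_ids` helper is hypothetical, not part of the patch; it merely shows a shape-tolerant extraction that works for both return styles of `apply_chat_template`.

```python
from collections.abc import Mapping


def extract_input_ids(tokenized):
    """Return input_ids whether the tokenizer gave a bare tensor or a mapping.

    Hypothetical helper: newer transformers return a dict-like BatchEncoding,
    older ones return the tensor directly.
    """
    if isinstance(tokenized, Mapping):
        # Covers dict and dict-like wrappers such as BatchEncoding.
        return tokenized["input_ids"]
    # Older behavior: apply_chat_template returned the tensor itself.
    return tokenized


# Works for a dict-shaped return...
assert extract_input_ids({"input_ids": [[1, 2, 3]]}) == [[1, 2, 3]]
# ...and for a bare tensor-like return (a nested list stands in for a tensor).
assert extract_input_ids([[1, 2, 3]]) == [[1, 2, 3]]
```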
