Commit cf70e89
Fix compute_hidden_states_hf.py: handle BatchEncoding from apply_chat_template (#1225)
## Summary
- `apply_chat_template(..., return_tensors="pt")` returns a
`BatchEncoding` in transformers 4.46+, which no longer subclasses `dict`
- The old guard `isinstance(tokenized, dict)` evaluates to `False` for
`BatchEncoding`, so `input_ids` was set to the whole `BatchEncoding`
object
- Calling `.shape[1]` on a `BatchEncoding` triggers
`__getattr__("shape")` → `AttributeError`
- Fix: check `isinstance(tokenized, torch.Tensor)` instead, which
correctly handles both old transformers (plain tensor) and new
transformers (BatchEncoding)
This is causing `test_collect_hidden_states` to fail in the speculative
decoding CI for all open PRs (#1207, #1210, #1221).
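A minimal sketch of the guard described above. The helper name `extract_input_ids` and the `FakeBatchEncoding` stand-in are assumptions for illustration, not the actual code in `compute_hidden_states_hf.py`; the stand-in just mimics the dict-style indexing of `transformers.BatchEncoding` so the example runs without downloading a model.

```python
import torch


def extract_input_ids(tokenized):
    """Return input_ids whether apply_chat_template produced a plain
    torch.Tensor (older transformers) or a BatchEncoding (4.46+).

    Checking for torch.Tensor replaces the broken isinstance(..., dict)
    guard: BatchEncoding fails that check, so the old code assigned the
    whole BatchEncoding to input_ids and later crashed on .shape[1].
    """
    if isinstance(tokenized, torch.Tensor):
        return tokenized
    # Dict-style access works for BatchEncoding (and plain dicts).
    return tokenized["input_ids"]


class FakeBatchEncoding:
    """Illustrative stand-in for transformers.BatchEncoding."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, key):
        return self.data[key]


# Old-transformers shape of the return value: a plain tensor.
old_style = torch.tensor([[1, 2, 3]])
# New-transformers shape: a BatchEncoding wrapping the tensor.
new_style = FakeBatchEncoding({"input_ids": torch.tensor([[1, 2, 3]])})

assert extract_input_ids(old_style).shape[1] == 3
assert extract_input_ids(new_style).shape[1] == 3
```

The type check on `torch.Tensor` is deliberately the positive case: it names the one concrete type the old path returns, so any wrapper type (BatchEncoding today, or a future container with `__getitem__`) falls through to dict-style access.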
## Test plan
- [ ] `torch-pr (speculative_decoding, 26.01)` CI passes
- [ ] Verify fix handles both `torch.Tensor` return (old transformers)
and `BatchEncoding` return (new transformers 4.46+)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 file changed: 6 additions & 4 deletions