Commit 0d69f0f
committed
fix(train/data): default create_empty_sample to bfloat16 to match training dtype
When vLLM hidden-state extraction times out during training, the training
loop substitutes an empty sample via `create_empty_sample()`. The original
implementation built `torch.empty(0, ...)` tensors with no `dtype` argument,
so PyTorch fell back to `float32`.
Downstream EAGLE-3 layers (`fc`, `verifier_lm_head`) load `bfloat16` weights
when training a bf16 verifier. The first time the empty sample reached one
of those layers we got an explicit dtype mismatch crash, taking the whole
job down at a random late-epoch step.
This patch:
- Adds a `dtype: torch.dtype = torch.bfloat16` keyword argument (covers the
common case of training against a bf16 verifier).
- Threads it through the `hidden_states` and `verifier_last_hidden_states`
tensors so the empty placeholders match downstream weight dtype.
- Also pins `input_ids` to `torch.long` (it was previously default float).
Reproducer: train an EAGLE-3 drafter against any bf16 verifier with
`extract_hidden_states` + `ExampleHiddenStatesConnector`. Sufficient vLLM
extraction time-outs eventually surface the empty-sample path and the run
crashes with `RuntimeError: ... expected Float, got BFloat16`. With this
default, the empty sample flows through the bf16 layers cleanly.
Callers that train against a non-bf16 verifier can override the dtype
explicitly.1 parent 4f80f5d commit 0d69f0f
1 file changed
Lines changed: 8 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
64 | 64 | | |
65 | 65 | | |
66 | 66 | | |
67 | | - | |
| 67 | + | |
68 | 68 | | |
69 | 69 | | |
70 | 70 | | |
| |||
73 | 73 | | |
74 | 74 | | |
75 | 75 | | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
76 | 80 | | |
77 | 81 | | |
78 | | - | |
79 | | - | |
80 | | - | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
81 | 85 | | |
82 | 86 | | |
83 | 87 | | |
| |||
0 commit comments