Commit 0d69f0f

fix(train/data): default create_empty_sample to bfloat16 to match training dtype
When vLLM hidden-state extraction times out during training, the training loop substitutes an empty sample via `create_empty_sample()`. The original implementation built `torch.empty(0, ...)` tensors with no `dtype` argument, so PyTorch fell back to `float32`. Downstream EAGLE-3 layers (`fc`, `verifier_lm_head`) load `bfloat16` weights when training against a bf16 verifier. The first time the empty sample reached one of those layers, we hit an explicit dtype-mismatch crash that took the whole job down at a random late-epoch step.

This patch:

- Adds a `dtype: torch.dtype = torch.bfloat16` keyword argument (covers the common case of training against a bf16 verifier).
- Threads it through the `hidden_states` and `verifier_last_hidden_states` tensors so the empty placeholders match the downstream weight dtype.
- Also pins `input_ids` to `torch.long` (previously it took the default float dtype).

Reproducer: train an EAGLE-3 drafter against any bf16 verifier with `extract_hidden_states` + `ExampleHiddenStatesConnector`. After enough vLLM extraction timeouts, the empty-sample path is hit and the run crashes with `RuntimeError: ... expected Float, got BFloat16`.

With this default, the empty sample flows through the bf16 layers cleanly. Callers that train against a non-bf16 verifier can override the dtype explicitly.
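The dtype fallback described in this message can be checked in isolation. The following is a minimal sketch, not code from the repository: `hidden_size` is an arbitrary illustrative value, the `fc` module here is a plain `torch.nn.Linear` stand-in for the real layer, and it assumes the process-wide default dtype has not been changed from `float32`.

```python
import torch

hidden_size = 8  # hypothetical size, chosen only for illustration

# With no dtype argument, torch.empty uses the global default dtype
# (float32 unless torch.set_default_dtype has been called).
implicit = torch.empty(0, 3 * hidden_size)

# Passing dtype explicitly yields a placeholder in the requested precision.
explicit = torch.empty(0, 3 * hidden_size, dtype=torch.bfloat16)

# Stand-in for a bf16 layer such as `fc`; not the project's actual module.
fc = torch.nn.Linear(3 * hidden_size, hidden_size).to(torch.bfloat16)

# Compare each placeholder's dtype with the layer's weight dtype.
implicit_matches = implicit.dtype == fc.weight.dtype
explicit_matches = explicit.dtype == fc.weight.dtype
print(implicit.dtype, explicit.dtype, fc.weight.dtype)
```

Only the explicitly typed placeholder agrees with the layer's weight dtype, which is the mismatch this patch removes.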
1 parent 4f80f5d commit 0d69f0f

1 file changed

Lines changed: 8 additions & 4 deletions

src/speculators/train/data.py

@@ -64,7 +64,7 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0):
 StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]]
 
 
-def create_empty_sample(hidden_size: int):
+def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16):
     # data structure: {
     # "hidden_states": [seq_len, 3 * hidden_size],
     # "input_ids": [seq_len],
@@ -73,11 +73,15 @@ def create_empty_sample(hidden_size: int):
     # "lengths": [1],
     # "position_ids": [seq_len],
     # }
+    # Default dtype is bfloat16 to match the hidden_states dtype used downstream.
+    # When this fallback is used (e.g. vLLM hidden-state extraction times out and
+    # we substitute an empty sample), the implicit float32 placeholders crashed
+    # bf16 EAGLE-3 layers (fc, verifier_lm_head) with a dtype mismatch.
 
     return {
-        "hidden_states": torch.empty(0, 3 * hidden_size),
-        "input_ids": torch.empty(0),
-        "verifier_last_hidden_states": torch.empty(0, hidden_size),
+        "hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype),
+        "input_ids": torch.empty(0, dtype=torch.long),
+        "verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype),
         "loss_mask": torch.empty(0),
         "lengths": torch.tensor([0], dtype=torch.long),
         "position_ids": torch.arange(0, dtype=torch.long),
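For callers training against a non-bf16 verifier, one possible way to pick the override is to read the dtype from the verifier-side weights. The sketch below is illustrative only: the function body is a trimmed copy containing just the keys visible in this diff (the real function has entries not shown in these hunks), and `verifier_lm_head` is a hypothetical fp16 `torch.nn.Linear`, not the project's module.

```python
import torch


def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16):
    # Trimmed copy for illustration: only the keys visible in the diff are
    # reproduced; the keys elided between the two hunks are left out.
    return {
        "hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype),
        "input_ids": torch.empty(0, dtype=torch.long),
        "verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype),
        "loss_mask": torch.empty(0),
        "lengths": torch.tensor([0], dtype=torch.long),
        "position_ids": torch.arange(0, dtype=torch.long),
    }


hidden_size = 16  # hypothetical size

# Hypothetical fp16 verifier head standing in for a non-bf16 verifier.
verifier_lm_head = torch.nn.Linear(hidden_size, 32).to(torch.float16)

# Default call: placeholders are bfloat16.
default_sample = create_empty_sample(hidden_size)

# Override: take the placeholder dtype from the verifier's weights.
fp16_sample = create_empty_sample(hidden_size, dtype=verifier_lm_head.weight.dtype)

print(default_sample["hidden_states"].dtype, fp16_sample["hidden_states"].dtype)
```

Deriving the dtype from a live weight rather than hard-coding it keeps the placeholder in step with the verifier if its precision changes. `loss_mask` is untouched by this patch and still uses the default float dtype.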
