From 1f8e6828f392078c33ed959a5145291822cc8188 Mon Sep 17 00:00:00 2001 From: coolthor Date: Sun, 17 May 2026 00:36:13 +0800 Subject: [PATCH 1/2] fix(train/data): default create_empty_sample to bfloat16 to match training dtype When vLLM hidden-state extraction times out during training, the training loop substitutes an empty sample via `create_empty_sample()`. The original implementation built `torch.empty(0, ...)` tensors with no `dtype` argument, so PyTorch fell back to `float32`. Downstream EAGLE-3 layers (`fc`, `verifier_lm_head`) load `bfloat16` weights when training against a bf16 verifier. The first time the empty sample reached one of those layers we got an explicit dtype mismatch crash, taking the whole job down at a random late-epoch step. This patch: - Adds a `dtype: torch.dtype = torch.bfloat16` keyword argument (covers the common case of training against a bf16 verifier). - Threads it through the `hidden_states` and `verifier_last_hidden_states` tensors so the empty placeholders match the downstream weight dtype. - Also pins `input_ids` to `torch.long` (it previously fell back to the default `float32`). Reproducer: train an EAGLE-3 drafter against any bf16 verifier with `extract_hidden_states` + `ExampleHiddenStatesConnector`. Once enough vLLM extraction timeouts occur, the empty-sample path is hit and the run crashes with `RuntimeError: ... expected Float, got BFloat16`. With this default, the empty sample flows through the bf16 layers cleanly. Callers that train against a non-bf16 verifier can override the dtype explicitly. 
Signed-off-by: Thor Lin --- src/speculators/train/data.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index b1e05c91..fbeee4fa 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -64,7 +64,7 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0): StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]] -def create_empty_sample(hidden_size: int): +def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16): # data structure: { # "hidden_states": [seq_len, 3 * hidden_size], # "input_ids": [seq_len], @@ -73,11 +73,15 @@ def create_empty_sample(hidden_size: int): # "lengths": [1], # "position_ids": [seq_len], # } + # Default dtype is bfloat16 to match the hidden_states dtype used downstream. + # When this fallback is used (e.g. vLLM hidden-state extraction times out and + # we substitute an empty sample), the implicit float32 placeholders crashed + # bf16 EAGLE-3 layers (fc, verifier_lm_head) with a dtype mismatch. 
return { - "hidden_states": torch.empty(0, 3 * hidden_size), - "input_ids": torch.empty(0), - "verifier_last_hidden_states": torch.empty(0, hidden_size), + "hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype), + "input_ids": torch.empty(0, dtype=torch.long), + "verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype), "loss_mask": torch.empty(0), "lengths": torch.tensor([0], dtype=torch.long), "position_ids": torch.arange(0, dtype=torch.long), From 5f4b33cb1298b58454cefc44b7580a3098ce5f7b Mon Sep 17 00:00:00 2001 From: coolthor Date: Sun, 17 May 2026 00:59:32 +0800 Subject: [PATCH 2/2] fix(train/data): thread dtype through create_collate_fn to create_empty_sample MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address CodeRabbit review feedback on PR #527: the caller of create_empty_sample — create_collate_fn — had no way to pass a dtype through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would still hit a dtype mismatch when the empty-sample fallback fires. Add a dtype keyword argument to create_collate_fn defaulting to torch.bfloat16, and thread it into the create_empty_sample call so callers can match the empty placeholder to their training precision. The bf16 default keeps existing bf16 training paths unchanged while unblocking non-bf16 callers cleanly. 
Signed-off-by: coolthor --- src/speculators/train/data.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index fbeee4fa..5007c222 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -434,6 +434,7 @@ def _get_raw_data(self, index): def create_collate_fn( max_len: int, hidden_size: int, + dtype: torch.dtype = torch.bfloat16, preprocess: Callable[[BatchType], BatchType] | None = None, ): def collate_fn(batch: list[BatchType | None]) -> BatchType: @@ -442,8 +443,11 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType: if not batch: # Create empty sample which then gets padded to full - # batch size if no valid samples are found - batch = [create_empty_sample(hidden_size)] + # batch size if no valid samples are found. + # Match the configured `dtype` so the placeholder doesn't crash + # downstream layers loaded at a different precision (e.g. bf16 + # weights vs fp32 default placeholders). + batch = [create_empty_sample(hidden_size, dtype=dtype)] collated_data = {} for key in batch[0]: # type: ignore[union-attr]