fix(train/data): default create_empty_sample to bfloat16 to match training dtype#527
coolthor wants to merge 2 commits into
…ining dtype

When vLLM hidden-state extraction times out during training, the training loop substitutes an empty sample via `create_empty_sample()`. The original implementation built `torch.empty(0, ...)` tensors with no `dtype` argument, so PyTorch fell back to `float32`.

Downstream EAGLE-3 layers (`fc`, `verifier_lm_head`) load `bfloat16` weights when training a bf16 verifier. The first time the empty sample reached one of those layers we got an explicit dtype mismatch crash, taking the whole job down at a random late-epoch step.

This patch:

- Adds a `dtype: torch.dtype = torch.bfloat16` keyword argument (covers the common case of training against a bf16 verifier).
- Threads it through the `hidden_states` and `verifier_last_hidden_states` tensors so the empty placeholders match downstream weight dtype.
- Also pins `input_ids` to `torch.long` (it was previously default float).

Reproducer: train an EAGLE-3 drafter against any bf16 verifier with `extract_hidden_states` + `ExampleHiddenStatesConnector`. Sufficient vLLM extraction time-outs eventually surface the empty-sample path and the run crashes with `RuntimeError: ... expected Float, got BFloat16`.

With this default, the empty sample flows through the bf16 layers cleanly. Callers that train against a non-bf16 verifier can override the dtype explicitly.

Signed-off-by: Thor Lin <coolthor@gmail.com>
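The shape of the change described in this commit can be sketched as follows. This is a simplified illustration, not the function from `src/speculators/train/data.py`: the name `create_empty_sample_sketch`, the tensor shapes, and the reduced set of keys are assumptions made for the example (the real sample also carries fields such as `loss_mask`, `lengths`, and `position_ids`).

```python
import torch


def create_empty_sample_sketch(
    hidden_size: int, dtype: torch.dtype = torch.bfloat16
) -> dict[str, torch.Tensor]:
    """Zero-length placeholder sample (hypothetical, simplified field set)."""
    return {
        # Floating-point placeholders follow the training dtype so they can
        # pass through layers whose weights are stored in that dtype.
        "hidden_states": torch.empty(0, hidden_size, dtype=dtype),
        "verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype),
        # Token ids are integer indices, so they are pinned to int64.
        "input_ids": torch.empty(0, dtype=torch.long),
    }


# Default: bf16 placeholders for a bf16 verifier.
bf16_sample = create_empty_sample_sketch(hidden_size=8)

# Override for a verifier trained in another precision.
fp32_sample = create_empty_sample_sketch(hidden_size=8, dtype=torch.float32)
```

Keeping `dtype` as a keyword argument with a default means existing bf16 call sites need no change, while other precisions require one explicit argument.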
0d69f0f to
1f8e682
Compare
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/speculators/train/data.py (1)
434-447: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Dtype inconsistency: `create_collate_fn` cannot override `create_empty_sample` dtype, but dataset defaults differ.

The `create_empty_sample` call at line 446 uses the default `torch.bfloat16`, but `create_collate_fn` has no way to pass a different dtype. This creates a critical inconsistency:

- `BaseDataset` and `ArrowDataset` default `hidden_states_dtype=torch.float` (lines 117, 184)
- `create_empty_sample` defaults to `dtype=torch.bfloat16` (line 67)
- The dataset's dtype conversion (line 145) only applies to samples from `__getitem__`, NOT to empty samples created in the collate function

When all samples in a batch fail (e.g., vLLM timeouts), the collate function creates a bf16 empty sample that won't match the dataset's configured dtype, potentially causing crashes for non-bf16 training.

The PR objectives state "Callers can override the default dtype for non-bf16 verifiers," but this caller cannot.

Recommended fixes:

- Option 1 (preferred): Add a `dtype` parameter to `create_collate_fn` and thread it through to `create_empty_sample`.
- Option 2: Change dataset defaults to `hidden_states_dtype=torch.bfloat16` to match the new empty sample default.
- Option 3: Make the collate function introspect the batch to determine the correct dtype.

💡 Proposed fix: Thread dtype through create_collate_fn

     def create_collate_fn(
         max_len: int,
         hidden_size: int,
    +    dtype: torch.dtype = torch.bfloat16,
         preprocess: Callable[[BatchType], BatchType] | None = None,
     ):
         def collate_fn(batch: list[BatchType | None]) -> BatchType:
             # Apply per-sample preprocessing and filter failed samples
             batch = [preprocess(b) if preprocess else b for b in batch if b is not None]
             if not batch:
                 # Create empty sample which then gets padded to full
                 # batch size if no valid samples are found
    -            batch = [create_empty_sample(hidden_size)]
    +            batch = [create_empty_sample(hidden_size, dtype=dtype)]

             collated_data = {}

Then callers of `create_collate_fn` should pass the dataset's `hidden_states_dtype`.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/train/data.py` around lines 434 - 447, create_collate_fn can emit an empty sample with the wrong dtype because it always calls create_empty_sample() which defaults to torch.bfloat16; add a dtype parameter to create_collate_fn (e.g., dtype: torch.dtype | None) and pass it through to create_empty_sample(dtype=...) when creating the empty sample, then update all call sites (where collate_fn is constructed) to pass the dataset's hidden_states_dtype (from BaseDataset / ArrowDataset) so empty samples match the dataset conversion; ensure the collate_fn still uses existing preprocess behavior and handles None appropriately.
🧹 Nitpick comments (1)
src/speculators/train/data.py (1)
67-88: ⚡ Quick win

Good explicit dtype control, but consider making `loss_mask` dtype explicit too.

The explicit `dtype` parameter and its application to `hidden_states` and `verifier_last_hidden_states` are a solid improvement. The explicit `torch.long` for `input_ids` is also good practice.

However, `loss_mask` (line 85) still uses an implicit dtype that defaults to float32. While this is likely correct, being explicit would improve clarity and prevent future confusion.

📝 Optional: Make loss_mask dtype explicit

    - "loss_mask": torch.empty(0),
    + "loss_mask": torch.empty(0, dtype=torch.float32),

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/train/data.py` around lines 67 - 88, create_empty_sample currently leaves loss_mask as torch.empty(0) which defaults to float32 implicitly; change the loss_mask creation to explicitly specify its dtype (e.g., torch.empty(0, dtype=torch.float32)) so the function clearly documents the intended type and avoids future ambiguity; update the "loss_mask" entry in create_empty_sample accordingly (alongside the existing explicit dtypes for hidden_states, verifier_last_hidden_states, input_ids, lengths, and position_ids).
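One way to see why an explicit `dtype` is more than cosmetic: `torch.empty` without a `dtype` argument follows PyTorch's process-wide default floating-point dtype, which other code can change. The snippet below is a small standalone illustration and is not taken from the repository.

```python
import torch

# Without a dtype argument, torch.empty follows the global default dtype.
implicit = torch.empty(0)
implicit_matches_default = implicit.dtype == torch.get_default_dtype()

# If some other code changes the global default, the implicit tensor drifts,
# while the explicitly typed tensor does not.
previous_default = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    drifted = torch.empty(0)
    pinned = torch.empty(0, dtype=torch.float32)
finally:
    # Restore the global default so the rest of the process is unaffected.
    torch.set_default_dtype(previous_default)
```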
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8ad4a5db-24cf-45c7-af06-0d50c1053a4a
📒 Files selected for processing (1)
src/speculators/train/data.py
…ty_sample

Address CodeRabbit review feedback on PR vllm-project#527: the caller of create_empty_sample (create_collate_fn) had no way to pass a dtype through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would still hit a dtype mismatch when the empty-sample fallback fires.

Add a dtype keyword argument to create_collate_fn defaulting to torch.bfloat16, and thread it into the create_empty_sample call so callers can match the empty placeholder to their training precision.

The bf16 default keeps existing bf16 training paths unchanged while unblocking non-bf16 callers cleanly.

Signed-off-by: Thor Lin <coolthor@gmail.com>
Thanks @coderabbitai — good catch on |
…ty_sample

Address CodeRabbit review feedback on PR vllm-project#527: the caller of create_empty_sample (create_collate_fn) had no way to pass a dtype through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would still hit a dtype mismatch when the empty-sample fallback fires.

Add a dtype keyword argument to create_collate_fn defaulting to torch.bfloat16, and thread it into the create_empty_sample call so callers can match the empty placeholder to their training precision.

The bf16 default keeps existing bf16 training paths unchanged while unblocking non-bf16 callers cleanly.

Signed-off-by: coolthor <coolthor@gmail.com>
e8b6576 to
5f4b33c
Compare
…, add bilingual video

- TLDR rewrite (zh-TW + en): "好消息 — Part 28 留下的壞消息,這次處理掉了" / "Good news — the Part 28 bottleneck is dealt with". Three paragraphs (setup → result → side bug) replacing the previous spec-sheet style.
- Stale PR-status text removed: all "PR in preparation / 準備中 / 還沒開" (zh-TW for "in preparation" / "not opened yet") mentions (FAQ, TLDR, body, references section) now link directly to the live upstream PR vllm-project/speculators#527 in both languages.
- Novelty block softened per Codex review: dropped "most-downloaded / high-leverage point to target" framing, added explicit caveat that Round 1's headline numbers are a single-config improvement against a re-used Part 28 baseline, with a paired same-prompt rerun + no-abliteration control queued for Round 2.
- HF README on coolthor/Huihui-...-eagle3-draft now points back at Part 28/29/30 in both languages (bidirectional traffic; pushed via huggingface_hub).
- Remotion explainer video added in both languages: mechanism-first 3-scene (Part 28 collapse curve → Round 1 retrain pipeline → flattened acceptance + 2.0x throughput). EN video uses -en.mp4 suffix per Part 28 convention.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
    - def create_empty_sample(hidden_size: int):
    + def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16):
Instead of setting the default to bfloat16, could we infer dtype from the model? It's not super intuitive to override the default.
We can actually set this explicitly when creating the collate function. It just isn't being piped in yet.
This is where create_collate_fn is called:
Line 99 in 4f80f5d
We explicitly set the hidden state dtype in scripts/train.py here:
Line 319 in 4f80f5d
This is then set as an attribute on the dataset object. So if you either get hidden_states_dtype from the dataset object passed into setup_dataloader or pass it into the fn directly, you can then pipe it into the create_collate_fn call.
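The wiring suggested in this comment could look roughly like the sketch below. Everything here is a hypothetical stand-in rather than the repository's code: `ToyDataset`, `make_collate_fn`, and `setup_dataloader_sketch` are illustrative names, the real `create_collate_fn` also takes `max_len` and `preprocess` and pads the batch, and the only detail taken from the discussion is that the dataset object carries a `hidden_states_dtype` attribute.

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset whose samples always fail, to hit the fallback."""

    def __init__(self, hidden_states_dtype: torch.dtype = torch.float):
        self.hidden_states_dtype = hidden_states_dtype

    def __len__(self) -> int:
        return 2

    def __getitem__(self, idx: int) -> Optional[dict]:
        return None  # simulates a timed-out hidden-state extraction


def make_collate_fn(hidden_size: int, dtype: torch.dtype):
    """Simplified collate factory that threads dtype into the empty fallback."""

    def collate_fn(batch):
        batch = [b for b in batch if b is not None]
        if not batch:
            # Fallback placeholder uses the dtype chosen by the caller.
            batch = [{"hidden_states": torch.empty(0, hidden_size, dtype=dtype)}]
        return {"hidden_states": torch.cat([b["hidden_states"] for b in batch])}

    return collate_fn


def setup_dataloader_sketch(dataset: ToyDataset, hidden_size: int) -> DataLoader:
    # Read the dtype from the dataset instead of relying on a default.
    collate_fn = make_collate_fn(hidden_size, dataset.hidden_states_dtype)
    return DataLoader(dataset, batch_size=2, collate_fn=collate_fn)


loader = setup_dataloader_sketch(ToyDataset(torch.float16), hidden_size=8)
fallback_batch = next(iter(loader))
```

Reading the dtype from the dataset keeps a single source of truth: whatever precision the training script configures for real samples is also used for the placeholder.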
fynnsu left a comment
One small comment. Otherwise looks good!
Problem
When `extract_hidden_states` + `ExampleHiddenStatesConnector` is used to train an EAGLE-3 drafter against a bf16 verifier, intermittent vLLM extraction time-outs cause the training loop to substitute an empty sample built by `create_empty_sample()`.

In the original implementation, `torch.empty` was called without an explicit `dtype`, which defaults to `float32`. Downstream EAGLE-3 layers (`fc`, `verifier_lm_head`) hold `bfloat16` weights when training a bf16 verifier, and the first time the empty sample reaches one of those layers PyTorch raises a dtype-mismatch error (quoted in the commit message as `RuntimeError: ... expected Float, got BFloat16`).

The whole training run dies at a random late-epoch step.
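The failure mode can be reproduced in isolation with a few lines of PyTorch. This is a standalone illustration, not the training code: the `fc` layer here is a plain `torch.nn.Linear` standing in for the EAGLE-3 layer of the same name, and a small non-empty batch is used so the example does not depend on how zero-length inputs are handled. The exact error text varies between PyTorch versions, so the snippet only checks that a `RuntimeError` is raised.

```python
import torch

hidden_size = 8

# Stand-in for a layer whose weights are stored in bfloat16.
fc = torch.nn.Linear(hidden_size, hidden_size).to(torch.bfloat16)

# A float32 input (what an untyped placeholder turns into) is rejected.
fp32_batch = torch.zeros(2, hidden_size)
try:
    fc(fp32_batch)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err

# An input created with the matching dtype passes through.
bf16_batch = torch.zeros(2, hidden_size, dtype=torch.bfloat16)
out = fc(bf16_batch)
```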
Repro
Fix
- Adds a `dtype: torch.dtype = torch.bfloat16` keyword argument to `create_empty_sample()` (covers the common case of training against a bf16 verifier).
- Threads it through the `hidden_states` and `verifier_last_hidden_states` placeholders so they match downstream weight dtype.
- Pins `input_ids` to `torch.long` (it previously fell back to float, which would surface separately for any consumer that does `embedding(input_ids)`).

Callers that train against a non-bf16 verifier (fp32, fp16, etc.) can override the default explicitly.
Tested with
`coolthor/Huihui-gemma-4-26B-A4B-it-abliterated-FP8-Dynamic` verifier + RedHatAI EAGLE-3 drafter, 50k Magpie samples, 1 epoch on a single DGX Spark GB10.

Writeup with full numbers: https://ai-muninn.com/en/blog/dgx-spark-eagle3-finetune-abliterated-round1
🤖 Generated with Claude Code