fix(train/data): default create_empty_sample to bfloat16 to match training dtype #527

Open
coolthor wants to merge 2 commits into vllm-project:main from coolthor:fix-create-empty-sample-bf16-default

Conversation

@coolthor

Problem

When extract_hidden_states + ExampleHiddenStatesConnector is used to train an EAGLE-3 drafter against a bf16 verifier, intermittent vLLM extraction time-outs cause the training loop to substitute an empty sample built by create_empty_sample().

In the original implementation:

return {
    "hidden_states": torch.empty(0, 3 * hidden_size),
    "input_ids": torch.empty(0),
    "verifier_last_hidden_states": torch.empty(0, hidden_size),
    ...
}

torch.empty without an explicit dtype defaults to float32. Downstream EAGLE-3 layers (fc, verifier_lm_head) hold bfloat16 weights when training a bf16 verifier, and the first time the empty sample reaches one of those layers PyTorch raises:

RuntimeError: ... expected Float, got BFloat16

The whole training run dies at a random late-epoch step.
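
The failure mode described above can be reproduced in isolation. The following is a minimal sketch, not the project's code: `hidden_size`, the stand-in `fc` layer, and the padded placeholder are assumptions chosen for illustration, and the exact wording of the error message may differ from the one quoted in this PR.

```python
import torch

hidden_size = 8  # illustrative size, not taken from the PR

# torch.empty with no dtype argument uses the global default dtype
# (float32 unless torch.set_default_dtype has been called).
placeholder = torch.empty(0, 3 * hidden_size)
print(placeholder.dtype)

# Stand-in for a bf16 projection layer such as `fc` (assumption: the real
# layer is a linear projection holding bfloat16 weights).
fc = torch.nn.Linear(3 * hidden_size, hidden_size, dtype=torch.bfloat16)

# Assumption: by the time the placeholder reaches the layer it has been
# padded to a non-zero length, keeping the placeholder's dtype.
padded = torch.zeros(4, 3 * hidden_size, dtype=placeholder.dtype)

try:
    fc(padded)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(f"dtype mismatch: {err}")

# Casting the input to the weight dtype lets the forward pass succeed.
out = fc(padded.to(torch.bfloat16))
```

The same mismatch occurs on any device, because the dtype check runs before the matrix multiply.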

Repro

1. Train an EAGLE-3 drafter against any bf16 verifier (e.g. `google/gemma-4-26B-A4B-it`).
2. Use `extract_hidden_states` + `ExampleHiddenStatesConnector` as the data path.
3. Let it run long enough that some vLLM extractions time out (high concurrency or large samples accelerate this).
4. Eventually a step substitutes the empty sample and the run crashes at `self.fc(hidden_states)` with the dtype mismatch above.

Fix

  • Add a dtype: torch.dtype = torch.bfloat16 keyword argument to create_empty_sample() (covers the common case of training against a bf16 verifier).
  • Apply it to the hidden_states and verifier_last_hidden_states placeholders so they match downstream weight dtype.
  • Pin input_ids to torch.long (it previously fell back to float, which would surface separately for any consumer that does embedding(input_ids)).

Callers that train against a non-bf16 verifier (fp32, fp16, etc.) can override the default explicitly.
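
As a rough sketch of the three bullet points above, the patched function might look like the following. The name `create_empty_sample_sketch` is hypothetical, and only the three keys shown in the Problem section are included; the keys elided there with `...` are deliberately left out.

```python
import torch


def create_empty_sample_sketch(
    hidden_size: int, dtype: torch.dtype = torch.bfloat16
) -> dict[str, torch.Tensor]:
    """Build zero-length placeholders whose dtypes match the training dtype."""
    return {
        # Floating-point placeholders follow the caller-supplied dtype.
        "hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype),
        # Token ids are integer indices, so they are pinned to torch.long.
        "input_ids": torch.empty(0, dtype=torch.long),
        "verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype),
    }


bf16_sample = create_empty_sample_sketch(hidden_size=16)
fp32_sample = create_empty_sample_sketch(hidden_size=16, dtype=torch.float32)
```

A caller training against an fp32 verifier would use the second form, passing `dtype=torch.float32` explicitly.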

Tested with

  • coolthor/Huihui-gemma-4-26B-A4B-it-abliterated-FP8-Dynamic verifier + RedHatAI EAGLE-3 drafter, 50k Magpie samples, 1 epoch on a single DGX Spark GB10.
  • Before the patch: the run crashed after ~9h, at step ~9485, with the BFloat16 / Float dtype mismatch.
  • After the patch: 11h run completed cleanly, pos 3 acceptance climbs to 72.7%.

Writeup with full numbers: https://ai-muninn.com/en/blog/dgx-spark-eagle3-finetune-abliterated-round1

🤖 Generated with Claude Code

@coderabbitai
Contributor

coderabbitai Bot commented May 16, 2026

Walkthrough

The PR adds an optional dtype parameter to create_empty_sample in the training data module, defaulting to torch.bfloat16. The function now explicitly constructs empty tensors using this dtype for hidden states while maintaining torch.long for input IDs.

Changes

Empty Sample Dtype Control

| Layer / File(s) | Summary |
| --- | --- |
| create_empty_sample dtype parameter (src/speculators/train/data.py) | Function signature gains optional dtype parameter (default torch.bfloat16), and implementation uses it to construct hidden_states and verifier_last_hidden_states tensors with consistent dtype, while input_ids remains torch.long. |

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

Suggested labels

bug, training

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately summarizes the main change: adding a bfloat16 default dtype parameter to create_empty_sample to fix dtype mismatches in training. |
| Description check | ✅ Passed | The description is directly related to the changeset, clearly explaining the problem (dtype mismatch in empty samples), the solution (adding dtype parameter), and providing test results. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot added the bug (Something isn't working) and training labels May 16, 2026
…ining dtype

When vLLM hidden-state extraction times out during training, the training
loop substitutes an empty sample via `create_empty_sample()`. The original
implementation built `torch.empty(0, ...)` tensors with no `dtype` argument,
so PyTorch fell back to `float32`.

Downstream EAGLE-3 layers (`fc`, `verifier_lm_head`) load `bfloat16` weights
when training a bf16 verifier. The first time the empty sample reached one
of those layers we got an explicit dtype mismatch crash, taking the whole
job down at a random late-epoch step.

This patch:

- Adds a `dtype: torch.dtype = torch.bfloat16` keyword argument (covers the
  common case of training against a bf16 verifier).
- Threads it through the `hidden_states` and `verifier_last_hidden_states`
  tensors so the empty placeholders match downstream weight dtype.
- Also pins `input_ids` to `torch.long` (it was previously default float).

Reproducer: train an EAGLE-3 drafter against any bf16 verifier with
`extract_hidden_states` + `ExampleHiddenStatesConnector`. Sufficient vLLM
extraction time-outs eventually surface the empty-sample path and the run
crashes with `RuntimeError: ... expected Float, got BFloat16`. With this
default, the empty sample flows through the bf16 layers cleanly.

Callers that train against a non-bf16 verifier can override the dtype
explicitly.

Signed-off-by: Thor Lin <coolthor@gmail.com>
@coolthor coolthor force-pushed the fix-create-empty-sample-bf16-default branch from 0d69f0f to 1f8e682 Compare May 16, 2026 16:39
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/speculators/train/data.py (1)

434-447: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Dtype inconsistency: create_collate_fn cannot override create_empty_sample dtype, but dataset defaults differ.

The create_empty_sample call at line 446 uses the default torch.bfloat16, but create_collate_fn has no way to pass a different dtype. This creates a critical inconsistency:

  1. BaseDataset and ArrowDataset default hidden_states_dtype=torch.float (lines 117, 184)
  2. create_empty_sample defaults to dtype=torch.bfloat16 (line 67)
  3. The dataset's dtype conversion (line 145) only applies to samples from __getitem__, NOT to empty samples created in the collate function

When all samples in a batch fail (e.g., vLLM timeouts), the collate function creates a bf16 empty sample that won't match the dataset's configured dtype, potentially causing crashes for non-bf16 training.

The PR objectives state "Callers can override the default dtype for non-bf16 verifiers," but this caller cannot.

Recommended fixes:

  • Option 1 (preferred): Add a dtype parameter to create_collate_fn and thread it through to create_empty_sample.
  • Option 2: Change dataset defaults to hidden_states_dtype=torch.bfloat16 to match the new empty sample default.
  • Option 3: Make the collate function introspect the batch to determine the correct dtype.
💡 Proposed fix: Thread dtype through create_collate_fn
 def create_collate_fn(
     max_len: int,
     hidden_size: int,
+    dtype: torch.dtype = torch.bfloat16,
     preprocess: Callable[[BatchType], BatchType] | None = None,
 ):
     def collate_fn(batch: list[BatchType | None]) -> BatchType:
         # Apply per-sample preprocessing and filter failed samples
         batch = [preprocess(b) if preprocess else b for b in batch if b is not None]
 
         if not batch:
             # Create empty sample which then gets padded to full
             # batch size if no valid samples are found
-            batch = [create_empty_sample(hidden_size)]
+            batch = [create_empty_sample(hidden_size, dtype=dtype)]
 
         collated_data = {}

Then callers of create_collate_fn should pass the dataset's hidden_states_dtype.
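
A self-contained sketch of Option 1 follows. It is not the project's implementation: `empty_sample` and `make_collate_fn` are hypothetical stand-ins, only the `hidden_states` key is collated, and padding to `max_len` is omitted. One design choice here differs from the diff above and is an assumption on my part: `dtype` is keyword-only and placed after `preprocess`, so that any existing caller passing `preprocess` positionally as the third argument would not silently bind it to `dtype`.

```python
from typing import Callable, Optional

import torch

Sample = dict[str, torch.Tensor]


def empty_sample(hidden_size: int, dtype: torch.dtype) -> Sample:
    # Hypothetical stand-in for create_empty_sample; only one key is shown.
    return {"hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype)}


def make_collate_fn(
    hidden_size: int,
    preprocess: Optional[Callable[[Sample], Sample]] = None,
    *,
    dtype: torch.dtype = torch.bfloat16,  # keyword-only by design
) -> Callable[[list[Optional[Sample]]], Sample]:
    def collate_fn(batch: list[Optional[Sample]]) -> Sample:
        # Drop failed samples (None), then apply per-sample preprocessing.
        valid = [preprocess(b) if preprocess else b for b in batch if b is not None]
        if not valid:
            # Every sample failed: fall back to a placeholder in the caller's dtype.
            valid = [empty_sample(hidden_size, dtype)]
        return {"hidden_states": torch.cat([s["hidden_states"] for s in valid], dim=0)}

    return collate_fn


collate = make_collate_fn(8, dtype=torch.float32)
all_failed = collate([None, None])
print(all_failed["hidden_states"].dtype, tuple(all_failed["hidden_states"].shape))
```

With this shape, a batch in which every extraction timed out yields a placeholder in the dtype the caller configured rather than the function's default.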

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/train/data.py` around lines 434 - 447, create_collate_fn can
emit an empty sample with the wrong dtype because it always calls
create_empty_sample() which defaults to torch.bfloat16; add a dtype parameter to
create_collate_fn (e.g., dtype: torch.dtype | None) and pass it through to
create_empty_sample(dtype=...) when creating the empty sample, then update all
call sites (where collate_fn is constructed) to pass the dataset's
hidden_states_dtype (from BaseDataset / ArrowDataset) so empty samples match the
dataset conversion; ensure the collate_fn still uses existing preprocess
behavior and handles None appropriately.
🧹 Nitpick comments (1)
src/speculators/train/data.py (1)

67-88: ⚡ Quick win

Good explicit dtype control, but consider making loss_mask dtype explicit too.

The explicit dtype parameter and its application to hidden_states and verifier_last_hidden_states is a solid improvement. The explicit torch.long for input_ids is also good practice.

However, loss_mask (line 85) still uses implicit dtype defaulting to float32. While this is likely correct, being explicit would improve clarity and prevent future confusion.

📝 Optional: Make loss_mask dtype explicit
-        "loss_mask": torch.empty(0),
+        "loss_mask": torch.empty(0, dtype=torch.float32),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/train/data.py` around lines 67 - 88, create_empty_sample
currently leaves loss_mask as torch.empty(0) which defaults to float32
implicitly; change the loss_mask creation to explicitly specify its dtype (e.g.,
torch.empty(0, dtype=torch.float32)) so the function clearly documents the
intended type and avoids future ambiguity—update the "loss_mask" entry in
create_empty_sample accordingly (alongside the existing explicit dtypes for
hidden_states, verifier_last_hidden_states, input_ids, lengths, and
position_ids).

ℹ️ Review info
📥 Commits

Reviewing files that changed from the base of the PR and between 4f80f5d and 0d69f0f.

📒 Files selected for processing (1)
  • src/speculators/train/data.py

coolthor added a commit to coolthor/speculators that referenced this pull request May 16, 2026
…ty_sample

Address CodeRabbit review feedback on PR vllm-project#527: the caller of
create_empty_sample — create_collate_fn — had no way to pass a dtype
through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would
still hit a dtype mismatch when the empty-sample fallback fires.

Add a dtype keyword argument to create_collate_fn defaulting to
torch.bfloat16, and thread it into the create_empty_sample call so callers
can match the empty placeholder to their training precision.

The bf16 default keeps existing bf16 training paths unchanged while
unblocking non-bf16 callers cleanly.

Signed-off-by: Thor Lin <coolthor@gmail.com>
@coolthor
Author

Thanks @coderabbitai — good catch on create_collate_fn being the only call site and not having a dtype kwarg of its own. Applied Option 1 in commit e8b6576: added dtype: torch.dtype = torch.bfloat16 to create_collate_fn and threaded it through to the create_empty_sample call. The bf16 default keeps existing training paths unchanged; non-bf16 callers can now override cleanly.

@coderabbitai
Contributor

coderabbitai Bot commented May 16, 2026

@coolthor: Great — threading dtype all the way through create_collate_fn is the clean solution. The API surface is now consistent: a caller that trains against a float32 (or float16) verifier can pass dtype=torch.float32 once at the collate-fn level and have it propagate automatically to the empty-sample fallback, with no hidden float32 surprises buried inside create_empty_sample. Thanks for the quick follow-up!

(╯°□°)╯ ✨🐰



…ty_sample

Address CodeRabbit review feedback on PR vllm-project#527: the caller of
create_empty_sample — create_collate_fn — had no way to pass a dtype
through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would
still hit a dtype mismatch when the empty-sample fallback fires.

Add a dtype keyword argument to create_collate_fn defaulting to
torch.bfloat16, and thread it into the create_empty_sample call so callers
can match the empty placeholder to their training precision.

The bf16 default keeps existing bf16 training paths unchanged while
unblocking non-bf16 callers cleanly.

Signed-off-by: coolthor <coolthor@gmail.com>
@coolthor coolthor force-pushed the fix-create-empty-sample-bf16-default branch from e8b6576 to 5f4b33c Compare May 16, 2026 17:01
coolthor added a commit to coolthor/ai-muninn that referenced this pull request May 16, 2026
…, add bilingual video

- TLDR rewrite (zh-TW + en): "好消息 — Part 28 留下的壞消息,這次處理掉了" / "Good news — the Part 28 bottleneck is dealt with". Three paragraphs (setup → result → side bug) replacing the previous spec-sheet style.
- Stale PR-status text removed: all "PR in preparation / 準備中 / 還沒開" mentions (FAQ, TLDR, body, references section) now link directly to the live upstream PR vllm-project/speculators#527 in both languages.
- Novelty block softened per Codex review: dropped "most-downloaded / high-leverage point to target" framing, added explicit caveat that Round 1's headline numbers are a single-config improvement against a re-used Part 28 baseline, with a paired same-prompt rerun + no-abliteration control queued for Round 2.
- HF README on coolthor/Huihui-...-eagle3-draft now points back at Part 28/29/30 in both languages (bidirectional traffic; pushed via huggingface_hub).
- Remotion explainer video added in both languages: mechanism-first 3-scene (Part 28 collapse curve → Round 1 retrain pipeline → flattened acceptance + 2.0x throughput). EN video uses -en.mp4 suffix per Part 28 convention.

🤖 Generated with [Claude Code](https://claude.com/claude-code)


-def create_empty_sample(hidden_size: int):
+def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16):
Collaborator

Instead of setting the default to bfloat16, could we infer dtype from the model? It's not super intuitive to override the default.

Collaborator

We can actually set this explicitly when creating the collate function. It just isn't being piped in yet.

This is where create_collate_fn is called:

collate_fn=create_collate_fn(args.total_seq_len, hidden_size, preprocess),

We explicitly set the hidden state dtype in scripts/train.py here:

hidden_states_dtype=hidden_states_dtype,

This is then set as an attribute on the dataset object. So if you either get hidden_states_dtype from the dataset object passed into setup_dataloader or pass it into the fn directly, you can then pipe it into the create_collate_fn call.
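
One way to picture the piping suggested above is sketched here. `ToyDataset` and `resolve_collate_dtype` are hypothetical names for illustration; the only detail taken from this thread is that the dataset object carries a `hidden_states_dtype` attribute. The real `setup_dataloader` signature is not shown in this conversation, so the final call appears only as a comment.

```python
import torch


class ToyDataset:
    """Hypothetical stand-in for a dataset that records its hidden-state dtype."""

    def __init__(self, hidden_states_dtype: torch.dtype = torch.float):
        self.hidden_states_dtype = hidden_states_dtype


def resolve_collate_dtype(dataset: object, fallback: torch.dtype = torch.bfloat16) -> torch.dtype:
    # Prefer the dtype the dataset was configured with; fall back only when
    # the object does not expose the attribute at all.
    return getattr(dataset, "hidden_states_dtype", fallback)


dtype = resolve_collate_dtype(ToyDataset(torch.float16))
print(dtype)

# Assumed call shape once create_collate_fn accepts a dtype keyword:
# collate_fn = create_collate_fn(args.total_seq_len, hidden_size, preprocess, dtype=dtype)
```

Reading the dtype from the dataset keeps a single source of truth, so the empty-sample fallback cannot drift from the dtype applied to regular samples.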

Collaborator

@fynnsu fynnsu left a comment


One small comment. Otherwise looks good!





Labels

bug (Something isn't working), training


3 participants