Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/speculators/train/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0):
StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]]


def create_empty_sample(hidden_size: int):
def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of setting the default to bfloat16, could we infer dtype from the model? It's not super intuitive to override the default.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can actually set this explicitly when creating the collate function. It just isn't being piped in yet.

This is where create_collate_fn is called:

collate_fn=create_collate_fn(args.total_seq_len, hidden_size, preprocess),

We explicitly set the hidden state dtype in scripts/train.py here:

hidden_states_dtype=hidden_states_dtype,

This is then set as an attribute on the dataset object. So if you either get hidden_states_dtype from the dataset object passed into setup_dataloader or pass it into the fn directly, you can then pipe it into the create_collate_fn call.

# data structure: {
# "hidden_states": [seq_len, 3 * hidden_size],
# "input_ids": [seq_len],
Expand All @@ -73,11 +73,15 @@ def create_empty_sample(hidden_size: int):
# "lengths": [1],
# "position_ids": [seq_len],
# }
# Default dtype is bfloat16 to match the hidden_states dtype used downstream.
# When this fallback is used (e.g. vLLM hidden-state extraction times out and
# we substitute an empty sample), the implicit float32 placeholders crashed
# bf16 EAGLE-3 layers (fc, verifier_lm_head) with a dtype mismatch.

return {
"hidden_states": torch.empty(0, 3 * hidden_size),
"input_ids": torch.empty(0),
"verifier_last_hidden_states": torch.empty(0, hidden_size),
"hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype),
"input_ids": torch.empty(0, dtype=torch.long),
"verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype),
"loss_mask": torch.empty(0),
"lengths": torch.tensor([0], dtype=torch.long),
"position_ids": torch.arange(0, dtype=torch.long),
Expand Down Expand Up @@ -430,6 +434,7 @@ def _get_raw_data(self, index):
def create_collate_fn(
max_len: int,
hidden_size: int,
dtype: torch.dtype = torch.bfloat16,
preprocess: Callable[[BatchType], BatchType] | None = None,
):
def collate_fn(batch: list[BatchType | None]) -> BatchType:
Expand All @@ -438,8 +443,11 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType:

if not batch:
# Create empty sample which then gets padded to full
# batch size if no valid samples are found
batch = [create_empty_sample(hidden_size)]
# batch size if no valid samples are found.
# Match the configured `dtype` so the placeholder doesn't crash
# downstream layers loaded at a different precision (e.g. bf16
# weights vs fp32 default placeholders).
batch = [create_empty_sample(hidden_size, dtype=dtype)]

collated_data = {}
for key in batch[0]: # type: ignore[union-attr]
Expand Down
Loading