diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index b1e05c91..5007c222 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -64,7 +64,7 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0): StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]] -def create_empty_sample(hidden_size: int): +def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16): # data structure: { # "hidden_states": [seq_len, 3 * hidden_size], # "input_ids": [seq_len], @@ -73,11 +73,15 @@ def create_empty_sample(hidden_size: int): # "lengths": [1], # "position_ids": [seq_len], # } + # Default dtype is bfloat16 to match the hidden_states dtype used downstream. + # When this fallback is used (e.g. vLLM hidden-state extraction times out and + # we substitute an empty sample), the implicit float32 placeholders crashed + # bf16 EAGLE-3 layers (fc, verifier_lm_head) with a dtype mismatch. return { - "hidden_states": torch.empty(0, 3 * hidden_size), - "input_ids": torch.empty(0), - "verifier_last_hidden_states": torch.empty(0, hidden_size), + "hidden_states": torch.empty(0, 3 * hidden_size, dtype=dtype), + "input_ids": torch.empty(0, dtype=torch.long), + "verifier_last_hidden_states": torch.empty(0, hidden_size, dtype=dtype), "loss_mask": torch.empty(0), "lengths": torch.tensor([0], dtype=torch.long), "position_ids": torch.arange(0, dtype=torch.long), @@ -430,6 +434,7 @@ def _get_raw_data(self, index): def create_collate_fn( max_len: int, hidden_size: int, + dtype: torch.dtype = torch.bfloat16, preprocess: Callable[[BatchType], BatchType] | None = None, ): def collate_fn(batch: list[BatchType | None]) -> BatchType: @@ -438,8 +443,11 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType: if not batch: # Create empty sample which then gets padded to full - # batch size if no valid samples are found - batch = [create_empty_sample(hidden_size)] + # batch size if no valid samples are found. + # Match the configured `dtype` so the placeholder doesn't crash + # downstream layers loaded at a different precision (e.g. bf16 + # weights vs fp32 default placeholders). + batch = [create_empty_sample(hidden_size, dtype=dtype)] collated_data = {} for key in batch[0]: # type: ignore[union-attr]