Skip to content

Commit e8b6576

Browse files
committed
fix(train/data): thread dtype through create_collate_fn to create_empty_sample
Address CodeRabbit review feedback on PR #527: the caller of create_empty_sample — create_collate_fn — had no way to pass a dtype through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would still hit a dtype mismatch when the empty-sample fallback fires. Add a dtype keyword argument to create_collate_fn defaulting to torch.bfloat16, and thread it into the create_empty_sample call so callers can match the empty placeholder to their training precision. The bf16 default keeps existing bf16 training paths unchanged while unblocking non-bf16 callers cleanly. Signed-off-by: Thor Lin <coolthor@gmail.com>
1 parent 1f8e682 commit e8b6576

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/speculators/train/data.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def _get_raw_data(self, index):
434434
def create_collate_fn(
435435
max_len: int,
436436
hidden_size: int,
437+
dtype: torch.dtype = torch.bfloat16,
437438
preprocess: Callable[[BatchType], BatchType] | None = None,
438439
):
439440
def collate_fn(batch: list[BatchType | None]) -> BatchType:
@@ -442,8 +443,11 @@ def collate_fn(batch: list[BatchType | None]) -> BatchType:
442443

443444
if not batch:
444445
# Create empty sample which then gets padded to full
445-
# batch size if no valid samples are found
446-
batch = [create_empty_sample(hidden_size)]
446+
# batch size if no valid samples are found.
447+
# Match the configured `dtype` so the placeholder doesn't crash
448+
# downstream layers loaded at a different precision (e.g. bf16
449+
# weights vs fp32 default placeholders).
450+
batch = [create_empty_sample(hidden_size, dtype=dtype)]
447451

448452
collated_data = {}
449453
for key in batch[0]: # type: ignore[union-attr]

0 commit comments

Comments
 (0)