Commit e8b6576
fix(train/data): thread dtype through create_collate_fn to create_empty_sample

Address CodeRabbit review feedback on PR #527: the caller of
create_empty_sample, create_collate_fn, had no way to pass a dtype
through, so non-bf16 trainers (e.g. fp32 verifiers, fp16 verifiers) would
still hit a dtype mismatch when the empty-sample fallback fires.

Add a dtype keyword argument to create_collate_fn defaulting to
torch.bfloat16, and thread it into the create_empty_sample call so callers
can match the empty placeholder to their training precision.

The bf16 default keeps existing bf16 training paths unchanged while
unblocking non-bf16 callers cleanly.

Signed-off-by: Thor Lin <coolthor@gmail.com>

1 parent 1f8e682 commit e8b6576
1 file changed
Lines changed: 6 additions & 2 deletions
| Hunk | Original file lines | New file lines | Change |
|---|---|---|---|
| 1 | 434-439 | 434-440 | 1 line added (new line 437) |
| 2 | 442-449 | 443-453 | 2 lines removed (original lines 445-446), 5 lines added (new lines 446-450) |
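The change described in the commit message can be illustrated with a minimal sketch. Everything below except the names `create_collate_fn`, `create_empty_sample`, the `dtype` keyword, and its `torch.bfloat16` default is an assumption made for illustration: the `hidden_size` parameter, the sample keys, and the batch-filtering logic are hypothetical and are not taken from the repository.

```python
import torch


def create_empty_sample(hidden_size: int, dtype: torch.dtype = torch.bfloat16) -> dict:
    # Hypothetical placeholder sample: zero-length tensors in the requested dtype.
    return {
        "input_ids": torch.zeros(0, dtype=torch.long),
        "hidden_states": torch.zeros(0, hidden_size, dtype=dtype),
    }


def create_collate_fn(hidden_size: int, dtype: torch.dtype = torch.bfloat16):
    # dtype defaults to bf16, so callers that never pass it see no change.
    def collate(batch):
        # Assumed convention: unusable samples arrive as None.
        valid = [sample for sample in batch if sample is not None]
        if not valid:
            # Empty-sample fallback: forward the caller's dtype so the
            # placeholder matches the training precision.
            return create_empty_sample(hidden_size, dtype=dtype)
        return {
            "input_ids": torch.cat([s["input_ids"] for s in valid]),
            "hidden_states": torch.cat([s["hidden_states"] for s in valid]),
        }

    return collate


# An fp32 caller opts in explicitly; the default path stays bf16.
fp32_empty = create_collate_fn(hidden_size=8, dtype=torch.float32)([None, None])
bf16_empty = create_collate_fn(hidden_size=8)([])
```

Under these assumptions, `fp32_empty["hidden_states"]` is a float32 tensor while `bf16_empty["hidden_states"]` stays bfloat16, which is the behavior the commit message describes: the placeholder follows the caller's precision and existing bf16 paths are unaffected.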