Commit 50b175d

Fix bf16 dtype mismatch in ZeRO-3 with zero_quantized_weights
With ZeRO-3, zero_quantized_weights=True, and bf16 enabled, the dequantized weights were incorrectly cast to fp16 instead of keeping their original bf16 dtype. This caused a RuntimeError during training with BERT and similar models.

The fix adds original_dtype tracking to AllGatherCoalescedHandle, mirroring the existing pattern in AllGatherHandle, so that weights are converted back to their original dtype after dequantization.

Fixes #7775

Signed-off-by: juyterman1000 <fastrunner10090@gmail.com>
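The dtype-tracking pattern described in the commit message can be illustrated with a small, library-free sketch. This is not the code from the commit: `FakeTensor`, `quantize`, `dequantize`, and `CoalescedHandleSketch` are hypothetical stand-ins, and the sketch assumes a dequantizer that always emits float16, which is the situation the message describes. The handle records each parameter's dtype before quantizing and casts back to it after dequantizing.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: a list of values plus a dtype label."""
    values: List[float]
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        # Only the label changes here; a real cast would also round the values.
        return FakeTensor(list(self.values), dtype)


def quantize(t: FakeTensor) -> Tuple[List[int], float]:
    """Symmetric int8 quantization; returns the integer codes and the scale."""
    peak = max(abs(v) for v in t.values)
    scale = peak / 127 if peak > 0 else 1.0
    return [round(v / scale) for v in t.values], scale


def dequantize(codes: List[int], scale: float) -> FakeTensor:
    """Assumption for this sketch: the dequantizer always emits float16."""
    return FakeTensor([c * scale for c in codes], "float16")


class CoalescedHandleSketch:
    """Hypothetical handle that remembers each parameter's dtype before quantizing."""

    def __init__(self, params: List[FakeTensor]) -> None:
        # Record the dtypes up front; the quantized payload no longer carries them.
        self.original_dtypes = [p.dtype for p in params]
        self.payloads = [quantize(p) for p in params]

    def wait(self) -> List[FakeTensor]:
        restored = []
        for (codes, scale), dtype in zip(self.payloads, self.original_dtypes):
            # Cast back to the recorded dtype instead of keeping float16.
            restored.append(dequantize(codes, scale).to(dtype))
        return restored


params = [
    FakeTensor([0.5, -1.0, 0.25], "bfloat16"),
    FakeTensor([2.0, 0.0], "bfloat16"),
]
restored = CoalescedHandleSketch(params).wait()
print([t.dtype for t in restored])
```

Without the recorded dtypes, `wait()` in this sketch would return float16 tensors for bfloat16 parameters, which is the kind of mismatch the commit message reports.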
1 parent 374f6d0 commit 50b175d

1 file changed

Lines changed: 237 additions & 202 deletions
