Commit 50b175d
Fix bf16 dtype mismatch in ZeRO-3 with zero_quantized_weights
When using ZeRO-3 with zero_quantized_weights=True and bf16 enabled,
the dequantized weights were incorrectly cast to fp16 instead of
preserving the original bf16 dtype. This caused a RuntimeError during
training with BERT and similar models.
The fix adds original_dtype tracking to AllGatherCoalescedHandle,
mirroring the existing pattern in AllGatherHandle, to ensure weights
are converted back to their original dtype after dequantization.
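
The pattern described above can be sketched roughly as follows. This is a simplified, torch-free illustration and not the actual DeepSpeed code: FakeTensor, dequantize, and CoalescedHandleSketch are hypothetical stand-ins, and the assumption that dequantization always yields fp16 is taken from the bug description, not from the library source.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: holds values and a dtype name only."""
    values: list
    dtype: str

    def to(self, dtype):
        # Return a copy tagged with the requested dtype.
        return FakeTensor(list(self.values), dtype)


def dequantize(quantized, scale):
    # Assumption: the dequantizer always emits fp16, as the bug report describes.
    return FakeTensor([v * scale for v in quantized], "float16")


class CoalescedHandleSketch:
    """Hypothetical sketch of a coalesced all-gather handle with dtype tracking."""

    def __init__(self, params, quantized, scale):
        self.quantized = quantized
        self.scale = scale
        # Record each parameter's dtype up front, before its data is replaced.
        self.original_dtypes = [p.dtype for p in params]

    def wait(self):
        # Dequantize each shard, then cast back to the recorded dtype.
        return [
            dequantize(q, self.scale).to(dtype)
            for q, dtype in zip(self.quantized, self.original_dtypes)
        ]


params = [FakeTensor([], "bfloat16")]
handle = CoalescedHandleSketch(params, quantized=[[1, 2]], scale=0.5)
result = handle.wait()
```

Without the final cast in wait(), the result would keep the dequantizer's fp16 dtype, which is the mismatch this commit addresses.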
Fixes #7775
Signed-off-by: juyterman1000 <fastrunner10090@gmail.com>
1 parent 374f6d0 commit 50b175d
1 file changed
Lines changed: 237 additions & 202 deletions