Fix bf16 dtype mismatch in ZeRO-3 with zero_quantized_weights
When using ZeRO-3 with zero_quantized_weights=True and bf16 enabled,
the dequantized weights were incorrectly cast to fp16 instead of
preserving the original bf16 dtype. This caused a RuntimeError during
training with BERT and similar models.
The fix adds original_dtype tracking to AllGatherCoalescedHandle,
mirroring the existing pattern in AllGatherHandle, to ensure weights
are converted back to their original dtype after dequantization.
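
The dtype-tracking pattern described above can be illustrated with a minimal, framework-free sketch. It is not the DeepSpeed implementation: `FakeTensor`, `dequantize`, and `CoalescedHandleSketch` are hypothetical stand-ins, and the dequantizer is assumed to always return fp16, mirroring the behavior reported in this commit.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; the dtype is tracked as a string."""
    values: List[float]
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        # Return a copy tagged with the requested dtype.
        return FakeTensor(list(self.values), dtype)


def dequantize(quantized: FakeTensor, scale: float) -> FakeTensor:
    # Assumption: the dequantizer always emits fp16, regardless of the
    # dtype the parameter had before quantization.
    return FakeTensor([v * scale for v in quantized.values], "float16")


class CoalescedHandleSketch:
    """Hypothetical handle that remembers each parameter's original dtype."""

    def __init__(self, params: List[FakeTensor],
                 quantized: List[FakeTensor], scales: List[float]):
        # Record the dtypes up front, before any quantized gather happens.
        self.original_dtypes = [p.dtype for p in params]
        self.quantized = quantized
        self.scales = scales

    def wait(self) -> List[FakeTensor]:
        restored = []
        for q, s, dtype in zip(self.quantized, self.scales, self.original_dtypes):
            w = dequantize(q, s)
            # Convert back only when the dequantized dtype differs.
            if w.dtype != dtype:
                w = w.to(dtype)
            restored.append(w)
        return restored


params = [FakeTensor([0.5, 1.0], "bfloat16")]
quantized = [FakeTensor([1, 2], "int8")]
handle = CoalescedHandleSketch(params, quantized, scales=[0.5])
restored = handle.wait()
```

Without the recorded `original_dtypes`, the weights in this sketch would stay in fp16 after `wait()`, which is the mismatch the commit describes.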
Fixes #7775