Skip to content

Commit dd9729f

Browse files
yfw and gshennvm committed
Fix maintain fp32 expert bias
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Co-Authored-By: Gerald Shen <geshen@nvidia.com>
1 parent c7a03ca commit dd9729f

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

src/megatron/bridge/models/model_provider.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -583,13 +583,20 @@ def get_model(
583583
model_module.cuda(torch.cuda.current_device())
584584

585585
if (model_config.fp16 or model_config.bf16) and mixed_precision_wrapper is not None:
586-
model = [mixed_precision_wrapper(model_config, model_module) for model_module in model]
587-
588-
# Maintain expert bias in float32 wrapped in Float16Module
586+
# Save expert bias in float32 to avoid precision loss during conversion
587+
keep_in_fp32 = []
589588
for model_module in model:
590589
for submodule in model_module.modules():
591590
if hasattr(submodule, "_maintain_float32_expert_bias"):
592-
submodule._maintain_float32_expert_bias()
591+
expert_bias = getattr(submodule, "expert_bias", None)
592+
if expert_bias is not None:
593+
keep_in_fp32.append((submodule, expert_bias.data.clone()))
594+
595+
model = [mixed_precision_wrapper(model_config, model_module) for model_module in model]
596+
597+
# Restore expert bias to float32
598+
for submodule, fp32_data in keep_in_fp32:
599+
submodule.expert_bias.data = fp32_data
593600

594601
if correct_amax_history_if_needed is not None:
595602
correct_amax_history_if_needed(model)

0 commit comments

Comments
 (0)