File tree Expand file tree Collapse file tree
src/megatron/bridge/models Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -583,13 +583,20 @@ def get_model(
583583 model_module .cuda (torch .cuda .current_device ())
584584
585585 if (model_config .fp16 or model_config .bf16 ) and mixed_precision_wrapper is not None :
586- model = [mixed_precision_wrapper (model_config , model_module ) for model_module in model ]
587-
588- # Maintain expert bias in float32 wrapped in Float16Module
586+ # Save expert bias in float32 to avoid precision loss during conversion
587+ keep_in_fp32 = []
589588 for model_module in model :
590589 for submodule in model_module .modules ():
591590 if hasattr (submodule , "_maintain_float32_expert_bias" ):
592- submodule ._maintain_float32_expert_bias ()
591+ expert_bias = getattr (submodule , "expert_bias" , None )
592+ if expert_bias is not None :
593+ keep_in_fp32 .append ((submodule , expert_bias .data .clone ()))
594+
595+ model = [mixed_precision_wrapper (model_config , model_module ) for model_module in model ]
596+
597+ # Restore expert bias to float32
598+ for submodule , fp32_data in keep_in_fp32 :
599+ submodule .expert_bias .data = fp32_data
593600
594601 if correct_amax_history_if_needed is not None :
595602 correct_amax_history_if_needed (model )
You can’t perform that action at this time.
0 commit comments