Commit d966544

Authored by yaoyu-33

[model] refactor: formalize hf_config on MegatronModelBridge (#3329)

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

1 parent 925c88f · commit d966544

6 files changed: 11 additions & 29 deletions

src/megatron/bridge/models/conversion/model_bridge.py

Lines changed: 4 additions & 0 deletions
@@ -287,6 +287,10 @@ def mapping_registry(self) -> MegatronMappingRegistry:
     # Set this in bridge subclasses to include model-specific files beyond standard artifacts
     ADDITIONAL_FILE_PATTERNS = None

+    # HuggingFace PretrainedConfig, set by register_bridge_implementation dispatch.
+    # Available in mapping_registry(), stream_weights_*(), and build_conversion_tasks().
+    hf_config = None
+
     # Common bidirectional config field name mapping: (hf_name, megatron_name)
     # Some mappings may not be used by all models - that's fine, unused fields are skipped
     CONFIG_MAPPING = [
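
The pattern this commit formalizes, as a self-contained sketch: the registration dispatch attaches the HuggingFace config to the bridge instance once, and subclasses read self.hf_config directly instead of stashing a private _hf_config in an overridden build_conversion_tasks. The classes below are runnable stand-ins named after the diff, not the actual Megatron-Bridge code:

# Runnable sketch with stand-in classes; only the names mirror the diff.
from types import SimpleNamespace


class MegatronModelBridge:
    # HuggingFace PretrainedConfig, attached externally by the registration
    # dispatch before conversion methods run; None until then.
    hf_config = None


class ExampleBridge(MegatronModelBridge):
    def mapping_registry(self):
        # New pattern: read the inherited attribute directly.
        if self.hf_config is None:
            return []  # no HF config attached yet
        return [("num_hidden_layers", self.hf_config.num_hidden_layers)]


bridge = ExampleBridge()
bridge.hf_config = SimpleNamespace(num_hidden_layers=32)  # stand-in dispatch
print(bridge.mapping_registry())  # [('num_hidden_layers', 32)]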

src/megatron/bridge/models/deepseek/deepseek_v2_bridge.py

Lines changed: 1 addition & 10 deletions
@@ -111,16 +111,7 @@ def megatron_to_hf_config(cls, provider) -> dict:

         return hf_cfg

-    def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        # Store config on instance for use in mapping_registry
-        from transformers import PretrainedConfig
-
-        self._hf_config = hf_pretrained if isinstance(hf_pretrained, PretrainedConfig) else hf_pretrained.config
-        return super().build_conversion_tasks(hf_pretrained, megatron_model)
-
     def mapping_registry(self) -> MegatronMappingRegistry:
-        # Get hf_config if available (set by build_conversion_tasks)
-        hf_config = getattr(self, "_hf_config", None)
+        hf_config = self.hf_config
         mapping_list = get_common_mapping_list(hf_config=hf_config)
         return MegatronMappingRegistry(*mapping_list)

src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py

Lines changed: 1 addition & 9 deletions
@@ -120,16 +120,8 @@ def megatron_to_hf_config(cls, provider: MLAModelProvider) -> dict:

         return hf_cfg

-    def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        # Store config on instance for use in mapping_registry
-        from transformers import PretrainedConfig
-
-        self._hf_config = hf_pretrained if isinstance(hf_pretrained, PretrainedConfig) else hf_pretrained.config
-        return super().build_conversion_tasks(hf_pretrained, megatron_model)
-
     def mapping_registry(self) -> MegatronMappingRegistry:
-        hf_config = getattr(self, "_hf_config", None)
+        hf_config = self.hf_config
         mapping_list = get_common_mapping_list(hf_config=hf_config)
         mapping_list.append(
             AutoMapping(

src/megatron/bridge/models/glm/glm45_bridge.py

Lines changed: 3 additions & 7 deletions
@@ -100,11 +100,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider
         return provider

     def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        from transformers import PretrainedConfig
-
-        # Store config on instance for use in mapping_registry
-        self._hf_config = hf_pretrained if isinstance(hf_pretrained, PretrainedConfig) else hf_pretrained.config
+        """Override to store HF state source before mapping_registry is called."""
         has_state = hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")
         self._hf_state_source = hf_pretrained.state.source if has_state else None
         self._hf_keys = list(self._hf_state_source.get_all_keys()) if self._hf_state_source else None
@@ -208,10 +204,10 @@ def mapping_registry(self) -> MegatronMappingRegistry:
             ]
         )
         # optionally add MTP mappings
-        if not hasattr(self, "_hf_config"):
+        if self.hf_config is None:
             logger.warning("No HF config found, skipping MTP mappings.")
             return MegatronMappingRegistry(*mapping_list)
-        hf_config = self._hf_config
+        hf_config = self.hf_config
         num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
         num_transformer_layers = hf_config.num_hidden_layers
         for mtp_layer in range(num_mtp_layers):
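
Note on the guard change above: once hf_config exists as a class-level default on the base bridge, a hasattr(self, ...) check stops working as a "config not attached" signal, because the attribute is always present; the test therefore becomes a comparison against None. A minimal stand-alone demonstration (not the bridge code itself):

class Base:
    hf_config = None  # class-level default, as on MegatronModelBridge


b = Base()
print(hasattr(b, "hf_config"))  # True even though nothing was attached
print(b.hf_config is None)      # True -> the reliable "missing" check now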

src/megatron/bridge/models/glm_vl/glm_45v_bridge.py

Lines changed: 1 addition & 2 deletions
@@ -89,8 +89,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> GLM45VModelProvider:
         return provider

     def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        self._hf_config = hf_pretrained.config
+        """Override to store HF state source before mapping_registry is called."""
         self._hf_state_source = hf_pretrained.state.source
         self._hf_keys = list(self._hf_state_source.get_all_keys())
         return super().build_conversion_tasks(hf_pretrained, megatron_model)

src/megatron/bridge/models/qwen_audio/modeling_qwen2_audio.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def __init__(

         # Store audio token id from config
         self.audio_token_id = getattr(config, "audio_token_id", 151646)
-        self.pad_token_id = getattr(config.hf_config, "pad_token_id", -1)
+        self.pad_token_id = getattr(config, "pad_token_id", -1)

     def set_input_tensor(self, input_tensor) -> None:
         """Set model chunk input tensor."""
