1- """Merge a Megatron LoRA adapter checkpoint with its base model and export to HuggingFace format.
1+ """Export a Megatron LoRA adapter checkpoint to HuggingFace format.
22
3- This is helpful when one wants to train the model using Megatron with LoRA adapter and then convert it to HuggingFace format
4- for inference and evaluation.
3+ This script supports two workflows:
4+
5+ 1. Merge the base model and LoRA adapter, then export a standard HuggingFace model.
6+ 2. Export only the LoRA adapter to a HuggingFace PEFT-compatible directory without merging.
57
68Usage (requires mcore extra):
79
10+ # Export adapter only (recommended when you want PEFT format)
11+ uv run --extra mcore python examples/converters/convert_lora_to_hf.py \
12+ --adapter-only \
13+ --adapter-ckpt results/dpo_glm5/step_5/policy/weights/iter_0000000 \
14+ --hf-model-name zai-org/GLM-5 \
15+ --hf-ckpt-path ./hf_lora_adapter
16+
17+ # Merge base model + adapter and export a full HF checkpoint
818 uv run --extra mcore python examples/converters/convert_lora_to_hf.py \
919 --base-ckpt ~/.cache/huggingface/nemo_rl/zai-org/GLM-5/iter_0000000 \
1020 --adapter-ckpt results/dpo_glm5/step_5/policy/weights/iter_0000000 \
2939
3040def parse_args ():
3141 parser = argparse .ArgumentParser (
32- description = "Merge Megatron LoRA adapter with base model and export to HF "
42+ description = "Export Megatron LoRA checkpoint to HuggingFace format "
3343 )
3444 parser .add_argument (
3545 "--base-ckpt" ,
3646 type = str ,
37- required = True ,
38- help = "Path to base model Megatron checkpoint (iter_XXXXXXX directory)" ,
47+ default = None ,
48+ help = "Path to base model Megatron checkpoint (iter_XXXXXXX directory). Required unless --adapter-only is set. " ,
3949 )
4050 parser .add_argument (
4151 "--adapter-ckpt" ,
@@ -53,9 +63,47 @@ def parse_args():
5363 "--hf-ckpt-path" ,
5464 type = str ,
5565 required = True ,
56- help = "Output path for merged HF checkpoint" ,
66+ help = "Output path for the exported HF checkpoint or adapter directory " ,
5767 )
58- return parser .parse_args ()
68+ parser .add_argument (
69+ "--adapter-only" ,
70+ action = "store_true" ,
71+ help = "Export only the LoRA adapter in HuggingFace PEFT format without merging into the base model." ,
72+ )
73+ args = parser .parse_args ()
74+ if not args .adapter_only and not args .base_ckpt :
75+ parser .error ("--base-ckpt is required unless --adapter-only is set" )
76+ return args
77+
78+
def export_lora_adapter_to_hf(
    adapter_ckpt: str,
    hf_model_name: str,
    hf_ckpt_path: str,
) -> str:
    """Convert a Megatron LoRA checkpoint into a HuggingFace PEFT adapter directory.

    Args:
        adapter_ckpt: LoRA adapter Megatron checkpoint directory (iter_XXXXXXX).
        hf_model_name: HuggingFace identifier of the base model the adapter was
            trained on top of.
        hf_ckpt_path: Destination directory for the PEFT adapter files.

    Returns:
        The *hf_ckpt_path* that was written to.

    Raises:
        FileExistsError: If *hf_ckpt_path* already exists (never clobber output).
    """
    # Guard clause: refuse to overwrite anything already at the destination.
    if os.path.exists(hf_ckpt_path):
        raise FileExistsError(f"Output path already exists: {hf_ckpt_path}")

    # Deferred import: megatron (mcore extra) is only required on the export path.
    from megatron.bridge import AutoBridge

    hf_bridge = AutoBridge.from_hf_pretrained(hf_model_name, trust_remote_code=True)
    logger.info("Exporting LoRA adapter in HuggingFace PEFT format...")
    hf_bridge.export_adapter_ckpt(adapter_ckpt, hf_ckpt_path)
    logger.info(f"Done! HF adapter saved to: {hf_ckpt_path}")
    return hf_ckpt_path
59107
60108
61109def merge_lora_to_hf (
@@ -86,13 +134,16 @@ def merge_lora_to_hf(
86134 from megatron .bridge import AutoBridge
87135 from megatron .bridge .peft .lora import LoRA
88136 from megatron .bridge .training .checkpointing import (
137+ _generate_model_state_dict ,
89138 _load_model_weights_from_checkpoint ,
139+ apply_peft_adapter_filter_to_state_dict ,
90140 )
91141 from megatron .bridge .training .model_load_save import (
92142 load_model_config ,
93143 megatron_cpu_init_context ,
94144 temporary_distributed_context ,
95145 )
146+ from megatron .core import dist_checkpointing
96147
97148 bridge = AutoBridge .from_hf_pretrained (hf_model_name , trust_remote_code = True )
98149
@@ -140,9 +191,10 @@ def merge_lora_to_hf(
140191 lora_B_init_method = peft_section .get ("lora_B_init_method" , "zero" ),
141192 a2a_experimental = peft_section .get ("a2a_experimental" , False ),
142193 )
143- model_cfg .peft = peft
144194
145- logger .info ("Building model with LoRA wrappers on CPU..." )
195+ logger .info (
196+ "Building base model on CPU (LoRA wrappers applied after base weights are loaded)..."
197+ )
146198 if hasattr (model_cfg , "finalize" ):
147199 model_cfg .finalize ()
148200 with megatron_cpu_init_context (model_cfg ):
@@ -159,8 +211,25 @@ def merge_lora_to_hf(
159211 _load_model_weights_from_checkpoint (base_ckpt , megatron_model , strict = False )
160212 gc .collect ()
161213
214+ logger .info ("Applying LoRA wrappers to model..." )
215+ megatron_model = peft (megatron_model , training = False )
216+ gc .collect ()
217+
162218 logger .info (f"Loading LoRA adapter from { adapter_ckpt } ..." )
163- _load_model_weights_from_checkpoint (adapter_ckpt , megatron_model , strict = False )
219+ adapter_sharded_state_dict = _generate_model_state_dict (megatron_model , {})
220+ adapter_sharded_state_dict = apply_peft_adapter_filter_to_state_dict (
221+ adapter_sharded_state_dict , peft
222+ )
223+ loaded_adapter_state_dict = dist_checkpointing .load (
224+ adapter_sharded_state_dict , adapter_ckpt
225+ )
226+ model_key = (
227+ "model"
228+ if "model" in loaded_adapter_state_dict
229+ else next (k for k in loaded_adapter_state_dict if k .startswith ("model" ))
230+ )
231+ for m in megatron_model :
232+ m .load_state_dict (loaded_adapter_state_dict [model_key ], strict = False )
164233 gc .collect ()
165234
166235 logger .info ("Saving merged model in HuggingFace format..." )
@@ -183,12 +252,19 @@ def merge_lora_to_hf(
183252
def main():
    """CLI entry point: parse arguments and dispatch to the selected workflow.

    With ``--adapter-only`` the LoRA adapter is exported to HuggingFace PEFT
    format without touching the base model; otherwise the adapter is merged
    into the base checkpoint and a full HF model is exported.
    """
    args = parse_args()
    # Arguments shared by both export workflows.
    shared_kwargs = {
        "adapter_ckpt": args.adapter_ckpt,
        "hf_model_name": args.hf_model_name,
        "hf_ckpt_path": args.hf_ckpt_path,
    }
    if args.adapter_only:
        export_lora_adapter_to_hf(**shared_kwargs)
    else:
        merge_lora_to_hf(base_ckpt=args.base_ckpt, **shared_kwargs)
192268
193269
194270if __name__ == "__main__" :