Commit 43f4642

feat: enhance convert_lora_to_hf script to support exporting LoRA adapters in HuggingFace PEFT format
Signed-off-by: ruit <ruit@nvidia.com>
1 parent 0baabd5 commit 43f4642

1 file changed: examples/converters/convert_lora_to_hf.py

Lines changed: 93 additions & 17 deletions
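The adapter directory exported by the new --adapter-only mode is meant to be consumable by standard HuggingFace PEFT tooling. A minimal sketch of loading it for inference, assuming the directory written by --adapter-only contains the usual PEFT adapter files (the model name and paths below mirror the usage example in the docstring and are illustrative):

    # Illustrative sketch: load the exported LoRA adapter on top of its base model.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("zai-org/GLM-5", trust_remote_code=True)
    model = PeftModel.from_pretrained(base, "./hf_lora_adapter")  # directory produced by --adapter-only
    tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-5", trust_remote_code=True)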
@@ -1,10 +1,20 @@
-"""Merge a Megatron LoRA adapter checkpoint with its base model and export to HuggingFace format.
+"""Export a Megatron LoRA adapter checkpoint to HuggingFace format.
 
-This is helpful when one wants to train the model using Megatron with LoRA adapter and then convert it to HuggingFace format
-for inference and evaluation.
+This script supports two workflows:
+
+1. Merge the base model and LoRA adapter, then export a standard HuggingFace model.
+2. Export only the LoRA adapter to a HuggingFace PEFT-compatible directory without merging.
 
 Usage (requires mcore extra):
 
+# Export adapter only (recommended when you want PEFT format)
+uv run --extra mcore python examples/converters/convert_lora_to_hf.py \
+    --adapter-only \
+    --adapter-ckpt results/dpo_glm5/step_5/policy/weights/iter_0000000 \
+    --hf-model-name zai-org/GLM-5 \
+    --hf-ckpt-path ./hf_lora_adapter
+
+# Merge base model + adapter and export a full HF checkpoint
 uv run --extra mcore python examples/converters/convert_lora_to_hf.py \
     --base-ckpt ~/.cache/huggingface/nemo_rl/zai-org/GLM-5/iter_0000000 \
     --adapter-ckpt results/dpo_glm5/step_5/policy/weights/iter_0000000 \
@@ -29,13 +39,13 @@
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="Merge Megatron LoRA adapter with base model and export to HF"
+        description="Export Megatron LoRA checkpoint to HuggingFace format"
     )
     parser.add_argument(
         "--base-ckpt",
         type=str,
-        required=True,
-        help="Path to base model Megatron checkpoint (iter_XXXXXXX directory)",
+        default=None,
+        help="Path to base model Megatron checkpoint (iter_XXXXXXX directory). Required unless --adapter-only is set.",
     )
     parser.add_argument(
         "--adapter-ckpt",
@@ -53,9 +63,47 @@ def parse_args():
         "--hf-ckpt-path",
         type=str,
         required=True,
-        help="Output path for merged HF checkpoint",
+        help="Output path for the exported HF checkpoint or adapter directory",
     )
-    return parser.parse_args()
+    parser.add_argument(
+        "--adapter-only",
+        action="store_true",
+        help="Export only the LoRA adapter in HuggingFace PEFT format without merging into the base model.",
+    )
+    args = parser.parse_args()
+    if not args.adapter_only and not args.base_ckpt:
+        parser.error("--base-ckpt is required unless --adapter-only is set")
+    return args
+
+
+def export_lora_adapter_to_hf(
+    adapter_ckpt: str,
+    hf_model_name: str,
+    hf_ckpt_path: str,
+) -> str:
+    """Export a Megatron LoRA checkpoint to HuggingFace PEFT adapter format.
+
+    Args:
+        adapter_ckpt: Path to the LoRA adapter Megatron checkpoint (iter_XXXXXXX directory).
+        hf_model_name: HuggingFace model identifier for the base model.
+        hf_ckpt_path: Output directory for the HuggingFace PEFT adapter files.
+
+    Returns:
+        The *hf_ckpt_path* that was written to.
+
+    Raises:
+        FileExistsError: If *hf_ckpt_path* already exists.
+    """
+    if os.path.exists(hf_ckpt_path):
+        raise FileExistsError(f"Output path already exists: {hf_ckpt_path}")
+
+    from megatron.bridge import AutoBridge
+
+    bridge = AutoBridge.from_hf_pretrained(hf_model_name, trust_remote_code=True)
+    logger.info("Exporting LoRA adapter in HuggingFace PEFT format...")
+    bridge.export_adapter_ckpt(adapter_ckpt, hf_ckpt_path)
+    logger.info(f"Done! HF adapter saved to: {hf_ckpt_path}")
+    return hf_ckpt_path
 
 
 def merge_lora_to_hf(
@@ -86,13 +134,16 @@ def merge_lora_to_hf(
     from megatron.bridge import AutoBridge
     from megatron.bridge.peft.lora import LoRA
     from megatron.bridge.training.checkpointing import (
+        _generate_model_state_dict,
         _load_model_weights_from_checkpoint,
+        apply_peft_adapter_filter_to_state_dict,
     )
     from megatron.bridge.training.model_load_save import (
         load_model_config,
         megatron_cpu_init_context,
         temporary_distributed_context,
     )
+    from megatron.core import dist_checkpointing
 
     bridge = AutoBridge.from_hf_pretrained(hf_model_name, trust_remote_code=True)
 
@@ -140,9 +191,10 @@ def merge_lora_to_hf(
         lora_B_init_method=peft_section.get("lora_B_init_method", "zero"),
         a2a_experimental=peft_section.get("a2a_experimental", False),
     )
-    model_cfg.peft = peft
 
-    logger.info("Building model with LoRA wrappers on CPU...")
+    logger.info(
+        "Building base model on CPU (LoRA wrappers applied after base weights are loaded)..."
+    )
     if hasattr(model_cfg, "finalize"):
         model_cfg.finalize()
     with megatron_cpu_init_context(model_cfg):
@@ -159,8 +211,25 @@ def merge_lora_to_hf(
     _load_model_weights_from_checkpoint(base_ckpt, megatron_model, strict=False)
     gc.collect()
 
+    logger.info("Applying LoRA wrappers to model...")
+    megatron_model = peft(megatron_model, training=False)
+    gc.collect()
+
     logger.info(f"Loading LoRA adapter from {adapter_ckpt}...")
-    _load_model_weights_from_checkpoint(adapter_ckpt, megatron_model, strict=False)
+    adapter_sharded_state_dict = _generate_model_state_dict(megatron_model, {})
+    adapter_sharded_state_dict = apply_peft_adapter_filter_to_state_dict(
+        adapter_sharded_state_dict, peft
+    )
+    loaded_adapter_state_dict = dist_checkpointing.load(
+        adapter_sharded_state_dict, adapter_ckpt
+    )
+    model_key = (
+        "model"
+        if "model" in loaded_adapter_state_dict
+        else next(k for k in loaded_adapter_state_dict if k.startswith("model"))
+    )
+    for m in megatron_model:
+        m.load_state_dict(loaded_adapter_state_dict[model_key], strict=False)
     gc.collect()
 
     logger.info("Saving merged model in HuggingFace format...")
@@ -183,12 +252,19 @@ def merge_lora_to_hf(
 
 def main():
     args = parse_args()
-    merge_lora_to_hf(
-        base_ckpt=args.base_ckpt,
-        adapter_ckpt=args.adapter_ckpt,
-        hf_model_name=args.hf_model_name,
-        hf_ckpt_path=args.hf_ckpt_path,
-    )
+    if args.adapter_only:
+        export_lora_adapter_to_hf(
+            adapter_ckpt=args.adapter_ckpt,
+            hf_model_name=args.hf_model_name,
+            hf_ckpt_path=args.hf_ckpt_path,
+        )
+    else:
+        merge_lora_to_hf(
+            base_ckpt=args.base_ckpt,
+            adapter_ckpt=args.adapter_ckpt,
+            hf_model_name=args.hf_model_name,
+            hf_ckpt_path=args.hf_ckpt_path,
+        )
 
 
 if __name__ == "__main__":
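If a fully merged checkpoint is still needed from an adapter that was exported with --adapter-only, the merge can also be done downstream with PEFT rather than rerunning the script with --base-ckpt. A minimal sketch, assuming a PEFT-compatible adapter directory as produced by --adapter-only (model name and paths are illustrative):

    # Illustrative sketch: merge the exported LoRA adapter into the base model and save a plain HF checkpoint.
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("zai-org/GLM-5", trust_remote_code=True)
    merged = PeftModel.from_pretrained(base, "./hf_lora_adapter").merge_and_unload()
    merged.save_pretrained("./hf_merged_model")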
