Skip to content

Commit 6eba340

Browse files
authored
Merge pull request #32 from alex-jw-brooks/merge_models
Merge models
2 parents 517652d + fb3f2e7 commit 6eba340

1 file changed

Lines changed: 91 additions & 0 deletions

File tree

tuning/utils/merge_model_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Standard
2+
from typing import Union
3+
import argparse
4+
import json
5+
import os
6+
7+
# Third Party
8+
from peft import PeftModel
9+
from tqdm import tqdm
10+
from transformers import AutoModelForCausalLM, AutoTokenizer
11+
12+
13+
def create_merged_model(
    checkpoint_models: Union[str, list[str]],
    export_path: str = None,
    base_model: str = None,
    save_tokenizer: bool = True,
):
    """Given a base model & checkpoint model(s) which were tuned with lora, load into memory
    & create a merged model. If an export path is specified, write it to disk. If multiple
    checkpoint models are provided, we merge_and_unload() them one after the other, which
    combines them with equal weights.

    TODO: In the future, it's probably a good idea to explore different combination schemes,
    which can likely be done using a combination of add_weighted_adapter() and merge_and_unload().

    References:
        - https://github.com/huggingface/peft/issues/1040
        - https://github.com/huggingface/peft/issues/280#issuecomment-1500805831
        - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter

    Args:
        checkpoint_models: Union[str, list[str]]
            One or more lora checkpoints containing adapters.
        export_path: str
            Path to export the merged model to. If None, nothing is written to disk.
        base_model: str
            Base model to be leveraged. If no base model is specified, the base model is pulled
            from the first checkpoint model's adapter config.
        save_tokenizer: bool
            Indicates whether or not we should save the tokenizer from the base model. Only
            used if the export_path is set.

    Returns:
        transformers model
            Merged model created from the checkpoint / base model.

    Raises:
        ValueError: If no checkpoint models are provided.
    """
    if isinstance(checkpoint_models, str):
        checkpoint_models = [checkpoint_models]
    # Guard against an empty list so we fail with a clear message instead of an
    # IndexError (base model inference) or silently returning the raw base model.
    if not checkpoint_models:
        raise ValueError("Must provide at least one checkpoint model to merge!")

    if base_model is None:
        # Fix: fetch_base_model_from_checkpoint() takes a single checkpoint path
        # (str); passing the whole list raised a TypeError inside os.path.join.
        # Infer the base model from the first checkpoint's adapter config.
        base_model = fetch_base_model_from_checkpoint(checkpoint_models[0])
    model = AutoModelForCausalLM.from_pretrained(base_model)

    # Merge each of the lora adapter models into the base model with equal weights
    for checkpoint_model in tqdm(checkpoint_models):
        model = PeftModel.from_pretrained(model, checkpoint_model)
        model = model.merge_and_unload()

    if export_path is not None:
        model.save_pretrained(export_path)
        # Export the tokenizer into the merged model dir
        if save_tokenizer:
            tokenizer = AutoTokenizer.from_pretrained(base_model)
            tokenizer.save_pretrained(export_path)
    return model
67+
68+
69+
def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
    """Inspects the checkpoint model, locates the adapter config, and grabs the
    base_model_name_or_path.

    Args:
        checkpoint_model: str
            Checkpoint model containing the adapter config, which specifies the base model.

    Returns:
        str
            base_model_name_or_path specified in the adapter config of the tuned peft model.

    Raises:
        FileNotFoundError: If the checkpoint dir has no adapter_config.json.
        KeyError: If the adapter config lacks base_model_name_or_path.
    """
    config_path = os.path.join(checkpoint_model, "adapter_config.json")
    if not os.path.isfile(config_path):
        raise FileNotFoundError("Unable to locate adapter config to infer base model!")

    with open(config_path, "r") as config_file:
        config_contents = json.load(config_file)

    if "base_model_name_or_path" not in config_contents:
        raise KeyError(
            "Base model adapter config exists, but has no base_model_name_or_path!"
        )
    return config_contents["base_model_name_or_path"]

0 commit comments

Comments
 (0)