Skip to content

Commit fb3f2e7

Browse files
formatting
1 parent e36f2ed commit fb3f2e7

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

tuning/utils/merge_model_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
# Standard
2+
from typing import Union
13
import argparse
24
import json
35
import os
4-
from typing import Union
5-
from tqdm import tqdm
6+
7+
# Third Party
68
from peft import PeftModel
9+
from tqdm import tqdm
710
from transformers import AutoModelForCausalLM, AutoTokenizer
811

12+
913
def create_merged_model(
1014
checkpoint_models: Union[str, list[str]],
11-
export_path: str=None,
12-
base_model: str=None,
13-
save_tokenizer: bool=True
15+
export_path: str = None,
16+
base_model: str = None,
17+
save_tokenizer: bool = True,
1418
):
1519
"""Given a base model & checkpoint model(s) which were tuned with lora, load into memory
1620
& create a merged model. If an export path is specified, write it to disk. If multiple
@@ -69,7 +73,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
6973
Args:
7074
checkpoint_model: str
7175
Checkpoint model containing the adapter config, which specifies the base model.
72-
76+
7377
Returns:
7478
str
7579
base_model_name_or_path specified in the adapter config of the tuned peft model.
@@ -81,5 +85,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
8185
with open(adapter_config, "r") as cfg:
8286
adapter_dict = json.load(cfg)
8387
if "base_model_name_or_path" not in adapter_dict:
84-
raise KeyError("Base model adapter config exists, but has no base_model_name_or_path!")
88+
raise KeyError(
89+
"Base model adapter config exists, but has no base_model_name_or_path!"
90+
)
8591
return adapter_dict["base_model_name_or_path"]

0 commit comments

Comments
 (0)