@@ -1,16 +1,20 @@
+ # Standard
+ from typing import Union
  import argparse
  import json
  import os
- from typing import Union
- from tqdm import tqdm
+
+ # Third Party
  from peft import PeftModel
+ from tqdm import tqdm
  from transformers import AutoModelForCausalLM, AutoTokenizer

+
  def create_merged_model(
      checkpoint_models: Union[str, list[str]],
-     export_path: str = None,
-     base_model: str = None,
-     save_tokenizer: bool = True
+     export_path: str = None,
+     base_model: str = None,
+     save_tokenizer: bool = True,
  ):
      """Given a base model & checkpoint model(s) which were tuned with lora, load into memory
      & create a merged model. If an export path is specified, write it to disk. If multiple
@@ -69,7 +73,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
      Args:
          checkpoint_model: str
              Checkpoint model containing the adapter config, which specifies the base model.
-
+
      Returns:
          str
              base_model_name_or_path specified in the adapter config of the tuned peft model.
@@ -81,5 +85,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
      with open(adapter_config, "r") as cfg:
          adapter_dict = json.load(cfg)
      if "base_model_name_or_path" not in adapter_dict:
-         raise KeyError("Base model adapter config exists, but has no base_model_name_or_path!")
+         raise KeyError(
+             "Base model adapter config exists, but has no base_model_name_or_path!"
+         )
      return adapter_dict["base_model_name_or_path"]
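
For reference, a minimal usage sketch of the two helpers touched by this diff. The module name `merge_model` and the paths below are placeholders, not taken from this PR; only the function and parameter names come from the code above.

```python
# Usage sketch only; `merge_model` is an assumed module name and the paths are placeholders.
from merge_model import create_merged_model, fetch_base_model_from_checkpoint

# Resolve the base model from the checkpoint's adapter_config.json
# (reads its base_model_name_or_path field, raising KeyError if it is missing).
base = fetch_base_model_from_checkpoint("checkpoints/checkpoint-500")

# Merge the LoRA adapter into the base model and write the merged weights
# (and, optionally, the tokenizer) to the export path.
create_merged_model(
    checkpoint_models="checkpoints/checkpoint-500",
    export_path="merged_model",
    base_model=base,
    save_tokenizer=True,
)
```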