 
 If these things change in the future, we should consider breaking it up.
 """
+# Standard
 import argparse
 import json
 import os
+
+# Third Party
 from peft import AutoPeftModelForCausalLM
-import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer
+import torch
 
 
 ### Utilities
@@ -30,10 +33,13 @@ class AdapterConfigPatcher:
             # When loaded in this block, the config's base_model_name_or_path is "foo"
             peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
     """
+
     def __init__(self, checkpoint_path: str, overrides: dict):
         self.checkpoint_path = checkpoint_path
         self.overrides = overrides
-        self.config_path = AdapterConfigPatcher._locate_adapter_config(self.checkpoint_path)
+        self.config_path = AdapterConfigPatcher._locate_adapter_config(
+            self.checkpoint_path
+        )
         # Values that we will patch later on
         self.patched_values = {}
 
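For context on the hunk above: per the class docstring, the patcher is meant to be used as a context manager wrapped around PEFT loading. A minimal usage sketch, with a hypothetical checkpoint path:

    # Temporarily patch adapter_config.json while the model loads; the
    # original values are restored when the block exits.
    with AdapterConfigPatcher("/path/to/checkpoint", {"base_model_name_or_path": "foo"}):
        peft_model = AutoPeftModelForCausalLM.from_pretrained("/path/to/checkpoint")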
@@ -58,7 +64,7 @@ def _locate_adapter_config(checkpoint_path: str) -> str:
     def _apply_config_changes(self, overrides: dict) -> dict:
         """Applies a patch to a config with some override dict, returning the values
         that we patched over so that they may be restored later.
-        
+
         Args:
             overrides: dict
                 Overrides to write into the adapter_config.json. Currently, we
@@ -99,7 +105,9 @@ def _get_old_config_values(adapter_config: dict, overrides: dict) -> dict:
         # For now, we only expect to patch the base model; this may change in the future,
         # but ensure that anything we are patching is defined in the original config
         if not set(overrides.keys()).issubset(set(adapter_config.keys())):
-            raise KeyError("Adapter config overrides must be set in the config being patched")
+            raise KeyError(
+                "Adapter config overrides must be set in the config being patched"
+            )
         return {key: adapter_config[key] for key in overrides}
 
     def __enter__(self):
@@ -119,7 +127,9 @@ def __init__(self, model, tokenizer, device):
         self.device = device
 
     @classmethod
-    def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "TunedCausalLM":
+    def load(
+        cls, checkpoint_path: str, base_model_name_or_path: str = None
+    ) -> "TunedCausalLM":
         """Loads an instance of this model.
 
         Args:
@@ -138,7 +148,11 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
             TunedCausalLM
                 An instance of this class on which we can run inference.
         """
-        overrides = {"base_model_name_or_path": base_model_name_or_path} if base_model_name_or_path is not None else {}
+        overrides = (
+            {"base_model_name_or_path": base_model_name_or_path}
+            if base_model_name_or_path is not None
+            else {}
+        )
         tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
         # Apply the configs to the adapter config of this model; if no overrides
         # are provided, then the context manager doesn't have any effect.
@@ -153,7 +167,6 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
         peft_model.to(device)
         return cls(peft_model, tokenizer, device)
 
-
     def run(self, text: str, *, max_new_tokens: int) -> str:
         """Runs inference on an instance of this model.
 
@@ -165,13 +178,17 @@ def run(self, text: str, *, max_new_tokens: int) -> str:
 
         Returns:
             str
-                Text generation result. 
+                Text generation result.
         """
         tok_res = self.tokenizer(text, return_tensors="pt")
         input_ids = tok_res.input_ids.to(self.device)
 
-        peft_outputs = self.peft_model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
-        decoded_result = self.tokenizer.batch_decode(peft_outputs, skip_special_tokens=False)[0]
+        peft_outputs = self.peft_model.generate(
+            input_ids=input_ids, max_new_tokens=max_new_tokens
+        )
+        decoded_result = self.tokenizer.batch_decode(
+            peft_outputs, skip_special_tokens=False
+        )[0]
         return decoded_result
 
 
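Taken together, load and run give the expected calling pattern for this class. A minimal sketch, assuming a local checkpoint directory (path and prompt are hypothetical placeholders):

    # Load the tuned adapter checkpoint, then generate up to 50 new tokens.
    model = TunedCausalLM.load("/path/to/checkpoint")
    result = model.run("Tell me about Paris.", max_new_tokens=50)
    print(result)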
@@ -180,7 +197,9 @@ def main():
     parser = argparse.ArgumentParser(
         description="Loads a tuned model and runs an inference call(s) through it"
     )
-    parser.add_argument("--model", help="Path to tuned model to be loaded", required=True)
+    parser.add_argument(
+        "--model", help="Path to tuned model to be loaded", required=True
+    )
     parser.add_argument(
         "--out_file",
         help="JSON file to write results to",
@@ -189,7 +208,7 @@ def main():
     parser.add_argument(
         "--base_model_name_or_path",
         help="Override for base model to be used [default: value in model adapter_config.json]",
-        default=None
+        default=None,
     )
     parser.add_argument(
         "--max_new_tokens",
@@ -199,7 +218,10 @@ def main():
     )
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument("--text", help="Text to run inference on")
-    group.add_argument("--text_file", help="File to be processed where each line is a text to run inference on")
+    group.add_argument(
+        "--text_file",
+        help="File to be processed where each line is a text to run inference on",
+    )
     args = parser.parse_args()
     # If we passed a file, check if it exists before doing anything else
     if args.text_file and not os.path.isfile(args.text_file):
@@ -220,7 +242,10 @@ def main():
 
     # TODO: we should add batch inference support
     results = [
-        {"input": text, "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens)}
+        {
+            "input": text,
+            "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens),
+        }
         for text in tqdm(texts)
     ]
 
@@ -230,5 +255,6 @@ def main():
 
     print(f"Exported results to: {args.out_file}")
 
+
 if __name__ == "__main__":
     main()
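For reference, the CLI wired up in main() supports invocations along these lines (the script name, paths, and prompts are hypothetical placeholders; --text and --text_file are mutually exclusive):

    # Single prompt passed inline
    python run_inference.py --model /path/to/checkpoint --text "Tell me about Paris." --out_file results.json

    # One prompt per line of a file
    python run_inference.py --model /path/to/checkpoint --text_file prompts.txt --out_file results.json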