@@ -256,6 +256,35 @@ def run_canonical_optimizations(self):
256256 assert res .graph_module is not None , "Pass returned None"
257257 self .pre_autograd_graph_module = res .graph_module
258258
def _check_calibration_prefix_options(self) -> None:
    """Validate the option combination used for prefix-based calibration.

    The configuration is accepted as soon as any one of ``use_kv_cache``,
    ``enable_dynamic_shape`` or ``generate_full_logits`` is truthy.

    Raises:
        ValueError: If all three options are falsy, i.e. the static,
            non-KV-cache path would be fed padded prefixes without full
            logits being available.
    """
    # Guard clause: any single enabled option makes the setup acceptable.
    if self.use_kv_cache or self.enable_dynamic_shape or self.generate_full_logits:
        return

    raise ValueError(
        "Static non-KV calibration with padded prefixes requires "
        "generate_full_logits so calibration can sample the last "
        "non-pad token position."
    )
270+
def _prepare_calibration_prefix(
    self, token_list: List[int], pos: int, max_len: int, pad_token: int
) -> Tuple[torch.Tensor, int]:
    """Build the model input for one calibration step without a KV cache.

    The input is made of tokens ``0..pos`` of ``token_list``, clipped to
    ``max_len`` entries. When ``self.enable_dynamic_shape`` is falsy the
    sequence is additionally right-padded with ``pad_token`` so that it is
    exactly ``max_len`` long.

    Args:
        token_list: Token ids accumulated so far.
        pos: Index of the last token to include in the prefix.
        max_len: Upper bound (and, for static shapes, exact size) of the
            prefix length.
        pad_token: Token id used to fill the tail of a static-shape prefix.

    Returns:
        A ``(prefix, logits_token_pos)`` pair. ``prefix`` is a ``1 x N``
        tensor whose dtype is taken from ``self.example_inputs[0]``;
        ``logits_token_pos`` is the index of the last non-pad token in it.
    """
    window = list(token_list[: pos + 1])

    # Computed before clipping/padding so it always refers to a real token.
    logits_token_pos = min(len(window), max_len) - 1

    # Both shape modes clip to max_len; only static shapes need padding.
    window = window[:max_len]
    if not self.enable_dynamic_shape:
        missing = max_len - len(window)
        if missing > 0:
            window += [pad_token] * missing

    token_dtype = self.example_inputs[0].dtype
    flat = torch.tensor(window, dtype=token_dtype)
    return flat.unsqueeze(0), logits_token_pos
287+
259288 def pt2e_calibrate (
260289 self ,
261290 prepared_module ,
@@ -266,39 +295,41 @@ def pt2e_calibrate(
266295 tokenizer_path ,
267296 ):
268297 logging .info ("Run calibration..." )
269- try :
270- from executorch .examples .models .llama .eval_llama_lib import (
271- GraphModuleEvalWrapper ,
272- )
273- from lm_eval .evaluator import simple_evaluate
274- except ImportError :
275- raise ImportError (
276- "Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
277- )
278-
298+ self ._check_calibration_prefix_options ()
279299 tokenizer = get_tokenizer (tokenizer_path )
280300
281301 def calibrate_template (
282302 module : torch .fx .GraphModule , tokenizer , prompts : str , max_len : int
283303 ):
284304 # TODO: change criteria & support batch inputs if necessary
285- pos = torch . tensor ( 0 , dtype = torch . int64 )
305+ pos = 0
286306 token_list = tokenizer .encode (prompts , bos = True , eos = False )
287307
308+ pad_token = getattr (tokenizer , "pad_id" , tokenizer .eos_id )
309+
288310 with torch .no_grad ():
289311 while token_list [- 1 ] != tokenizer .eos_id and pos < max_len :
290- logits = module (
291- torch .full ((1 , 1 ), token_list [pos ]),
292- {"input_pos" : torch .tensor ((pos ,))},
293- )
312+ logits_token_pos = - 1
313+ if self .use_kv_cache :
314+ logits = module (
315+ torch .full ((1 , 1 ), token_list [pos ]),
316+ {"input_pos" : torch .tensor ((pos ,))},
317+ )
318+ else :
319+ prefix , logits_token_pos = self ._prepare_calibration_prefix (
320+ token_list , pos , max_len , pad_token
321+ )
322+ logits = module (prefix )
323+
294324 pos += 1
295325 if pos >= len (token_list ):
296326 if self .generate_full_logits :
297- token_list . append (
298- torch . argmax ( logits [:, - 1 ], dim = - 1 ). item ()
299- )
327+ next_token = torch . argmax (
328+ logits [:, logits_token_pos ], dim = - 1
329+ ). item ()
300330 else :
301- token_list .append (torch .argmax (logits [:], dim = - 1 ).item ())
331+ next_token = torch .argmax (logits [:], dim = - 1 ).item ()
332+ token_list .append (next_token )
302333
303334 calibrate_template (
304335 module = prepared_module ,
@@ -307,26 +338,41 @@ def calibrate_template(
307338 max_len = calibration_seq_length ,
308339 )
309340
310- eval_wrapper = GraphModuleEvalWrapper (
311- model = prepared_module ,
312- tokenizer = tokenizer ,
313- max_seq_length = calibration_seq_length ,
314- use_kv_cache = self .use_kv_cache ,
315- generate_full_logits = self .generate_full_logits ,
316- enable_dynamic_shape = self .enable_dynamic_shape ,
317- )
341+ if calibration_tasks :
342+ try :
343+ from executorch .examples .models .llama .eval_llama_lib import (
344+ GraphModuleEvalWrapper ,
345+ )
346+ from lm_eval .evaluator import simple_evaluate
347+ except ImportError :
348+ raise ImportError (
349+ "Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
350+ )
318351
319- # Evaluate the model
320- with torch .no_grad ():
321- eval_results = simple_evaluate (
322- model = eval_wrapper ,
323- tasks = calibration_tasks ,
324- limit = calibration_limit ,
352+ eval_wrapper = GraphModuleEvalWrapper (
353+ model = prepared_module ,
354+ tokenizer = tokenizer ,
355+ max_seq_length = calibration_seq_length ,
356+ use_kv_cache = self .use_kv_cache ,
357+ generate_full_logits = self .generate_full_logits ,
358+ enable_dynamic_shape = self .enable_dynamic_shape ,
359+ # The exported graph can contain ops like aten.full.default
360+ # without explicit device, which default to CPU and can
361+ # trigger device-mismatch errors when lm_eval runs on CUDA.
362+ # Calibrate on CPU for stability.
363+ device = "cpu" ,
325364 )
326365
327- for task , res in eval_results ["results" ].items ():
328- print (f"{ task } : { res } " )
329- logging .info ("Calibration finish..." )
366+ with torch .no_grad ():
367+ eval_results = simple_evaluate (
368+ model = eval_wrapper ,
369+ tasks = calibration_tasks ,
370+ limit = calibration_limit ,
371+ )
372+
373+ for task , res in eval_results ["results" ].items ():
374+ print (f"{ task } : { res } " )
375+ logging .info ("Calibration finish..." )
330376
331377 def pt2e_quantize (self , quantizers : Optional [List [Quantizer ]]) -> "LLMEdgeManager" :
332378 """
@@ -351,18 +397,19 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
351397 assert (
352398 self .pre_autograd_graph_module is not None
353399 ), "Please run export() first"
400+ if self .calibration_tasks and self .calibration_limit is None :
401+ logging .warning (
402+ "calibration_tasks provided without calibration_limit; "
403+ "lm-eval will run the full task dataset during "
404+ "calibration."
405+ )
354406 m = prepare_pt2e (
355407 self .pre_autograd_graph_module , # pyre-ignore[6]
356408 composed_quantizer ,
357409 )
358- logging .info (
359- f"Calibrating with tasks: { self .calibration_tasks } , limit: { self .calibration_limit } , calibration_data: { self .calibration_data } , tokenizer_path: { self .tokenizer_path } , seq_length: { self .calibration_seq_length } "
360- )
361410 # Calibrate
362411 if (
363- self .calibration_tasks is not None
364- and self .calibration_limit is not None
365- and self .calibration_seq_length is not None
412+ self .calibration_seq_length is not None
366413 and self .calibration_data is not None
367414 and self .tokenizer_path is not None
368415 ):
0 commit comments