@@ -77,6 +77,13 @@ class DecoderInputs:
     embedding: Optional[torch.Tensor] = None
 
 
+@dataclass
+class DecoderOutputs:
+    logits: Optional[torch.Tensor] = None
+    token_list: Optional[List[int]] = None
+    input_samples: Optional[List] = None
+
+
 class GraphModuleCalibrationWrapper(EagerEvalWrapper):
     """
     A wrapper class for calibration
@@ -94,6 +101,7 @@ def __init__( # noqa: C901
         get_example_inputs: Callable,
         use_i64_token: bool,
         seq_mse_candidates: int,
+        collect_input_samples: bool = False,
     ):
         # n seq len = n-1 cache len, so len(inps) = n-1 during _model_call
         assert max_seq_length is not None, "max_seq_length must be provided"
@@ -108,30 +116,32 @@ def __init__( # noqa: C901
         self.use_i64_token = use_i64_token
         self.seq_mse_candidates = seq_mse_candidates
         self._input_samples = None
+        self.collect_input_samples = collect_input_samples
 
     def get_input_samples(self):
         return self._input_samples
 
     def _model_call(self, inps):
-        all_logits = None
         kwargs = {}
         if self._use_kv_cache:
             kwargs["ar_len"] = self.ar_len
             kwargs["seq_mse_candidates"] = self.seq_mse_candidates
 
-        all_logits, self._input_samples = INFERENCE_REGISTRY[self._use_kv_cache](
+        result = INFERENCE_REGISTRY[self._use_kv_cache](
             self.get_example_inputs,
             inps,
             self._model,
             self._tokenizer,
             max_seq_len=self.max_seq_length,
             use_i64_token=self.use_i64_token,
             collect_logits=True,
+            collect_input_samples=self.collect_input_samples,
             **kwargs,
         )
+        self._input_samples = result.input_samples
         # one shot is enough for seq mse
         self.seq_mse_candidates = 0
-        return all_logits
+        return result.logits
 
 
 class LookaheadDecoder:
@@ -731,7 +741,8 @@ def kv_inference( # noqa: C901
     collect_logits=False,
     seq_mse_candidates=0,
     lookahead_config=None,
-):
+    collect_input_samples=False,
+) -> DecoderOutputs:
     input_samples = []  # Record input sample for quantization error analysis
     is_multimodal = all(
         [
@@ -818,6 +829,7 @@ def kv_inference( # noqa: C901
 
     # record total input tokens and generated tokens
     total_token_list = prompt_token_list
+    last_token_in_prompt = prompt_token_list[-1] if len(prompt_token_list) > 0 else None
 
     # 3. prepare decoder inputs
     inputs = DecoderInputs(
@@ -845,28 +857,33 @@ def kv_inference( # noqa: C901
 
     # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
     # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
-    generate_input_sample = _generate(
-        inputs,
-        cur_pos,
-        module,
-        tokenizer,
-        tok_embedding,
-        ar_len,
-        max_seq_len,
-        k_caches,
-        v_caches,
-        total_token_list,
-        lookahead_config,
-    )
-    if generate_input_sample is not None:
-        input_samples.append(generate_input_sample)
-    else:
-        input_samples.append(prefill_input_sample)
+    generate_input_sample = None
+    if last_token_in_prompt != tokenizer.eos_id:
+        generate_input_sample = _generate(
+            inputs,
+            cur_pos,
+            module,
+            tokenizer,
+            tok_embedding,
+            ar_len,
+            max_seq_len,
+            k_caches,
+            v_caches,
+            total_token_list,
+            lookahead_config,
+        )
+
+    if collect_input_samples:
+        input_samples.append(generate_input_sample or prefill_input_sample)
 
     logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
-    return result_logits, input_samples
+    return DecoderOutputs(
+        logits=result_logits if collect_logits else None,
+        token_list=total_token_list,
+        input_samples=input_samples if collect_input_samples else None,
+    )
 
 
 @register_inference(use_kv_cache=False)
@@ -882,7 +899,8 @@ def prefill_inference(
     max_seq_len=512,
     use_i64_token=False,
     collect_logits=False,
-):
+    collect_input_samples=False,
+) -> DecoderOutputs:
     input_samples = None  # Record input sample for quantization error analysis
     is_multimodal = all(
         [
@@ -950,7 +968,11 @@ def prefill_inference(
         pos += 1
     if isinstance(prompt, str):
         logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
-    return result_logits, [input_samples]
+    return DecoderOutputs(
+        logits=result_logits if collect_logits else None,
+        token_list=token_list,
+        input_samples=[input_samples] if collect_input_samples else None,
+    )
 
 
 def graph_module_inference(
@@ -972,7 +994,8 @@ def graph_module_inference(
     event_name: Optional[str] = None,
     seq_mse_candidates: int = 0,
     lookahead_config: Optional[Tuple[int]] = None,
-):
+    collect_input_samples: bool = False,
+) -> DecoderOutputs:
     """
     This function supports model execution from a static nn.Module decoder model
     all the way to an edge program.
@@ -988,7 +1011,7 @@ def graph_module_inference(
             kwargs["ar_len"] = ar_len
             kwargs["lookahead_config"] = lookahead_config
 
-        _, input_samples = INFERENCE_REGISTRY[use_kv_cache](
+        result = INFERENCE_REGISTRY[use_kv_cache](
             get_example_inputs,
             prompt,
             module,
@@ -1000,10 +1023,11 @@ def graph_module_inference(
             max_seq_len=max_seq_len,
             use_i64_token=use_i64_token,
             collect_logits=False,
+            collect_input_samples=collect_input_samples,
             **kwargs,
         )
         logging.info(f"Prompt summary for {event_name}")
-        return input_samples
+        return result
     else:
         calibration_wrapper = GraphModuleCalibrationWrapper(
             model=module,
@@ -1014,6 +1038,7 @@ def graph_module_inference(
             get_example_inputs=get_example_inputs,
             use_i64_token=use_i64_token,
             seq_mse_candidates=seq_mse_candidates,
+            collect_input_samples=collect_input_samples,
         )
         with torch.no_grad():
             eval_results = simple_evaluate(
@@ -1026,4 +1051,4 @@ def graph_module_inference(
         for task, res in eval_results["results"].items():
             logging.info(f"{task}: {res}")
 
-        return calibration_wrapper.get_input_samples()
+        return DecoderOutputs(input_samples=calibration_wrapper.get_input_samples())
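
For reference, a minimal, self-contained sketch of the return convention this diff introduces (not part of the patch; the tensor shape and printed values are illustrative only, and the dataclass simply mirrors the DecoderOutputs definition added above). Callers that previously unpacked a (logits, input_samples) tuple now read named, optional fields, with unrequested fields left as None.

# Sketch only: shows how a caller reads DecoderOutputs fields instead of
# unpacking a (logits, input_samples) tuple. Values are dummies; in the real
# flow they come from kv_inference / prefill_inference / graph_module_inference.
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class DecoderOutputs:
    logits: Optional[torch.Tensor] = None
    token_list: Optional[List[int]] = None
    input_samples: Optional[List] = None


result = DecoderOutputs(logits=torch.zeros(1, 4, 128), token_list=[1, 2, 3])

if result.logits is not None:  # populated only when collect_logits=True
    print(result.logits.shape)
if result.input_samples is None:  # stays None unless collect_input_samples=True
    print("input samples were not collected")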