
Commit c3f07e0

Qualcomm AI Engine Direct - Decouple quantization and graph compilation for faster VLM/LLM PTQ

Summary:
- Calibrate the decoder using the prefill stage only (full-chunk input_ids)
- Remove the need for AR-N calibration loops
- Significantly reduce calibration overhead
1 parent e4ede92 commit c3f07e0

4 files changed: 435 additions & 171 deletions
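
The gist of the change: rather than replaying each calibration prompt through an AR-N decode loop (one autoregressive graph step per N-token chunk), the decoder is now calibrated with a single prefill-style pass over the full chunk of input_ids. A minimal sketch of the idea, with hypothetical names (calibrate_prefill_only and tokenized_prompts are illustrations, not part of this commit):

    import torch

    def calibrate_prefill_only(model, tokenized_prompts, max_context_len):
        # One forward pass per chunk: the quantizer's observers record
        # activation ranges during this single prefill-style call, so no
        # per-token autoregressive loop is needed for calibration.
        with torch.no_grad():
            for input_ids in tokenized_prompts:
                model(input_ids[:, :max_context_len])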


examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py

Lines changed: 2 additions & 1 deletion
@@ -424,7 +424,7 @@ def __init__(
         self.max_seq_length = pte_max_context_len
 
     def run(self, prompt):
-        golden_logits, _ = INFERENCE_REGISTRY[True](
+        result = INFERENCE_REGISTRY[True](
             get_example_inputs=self.get_example_inputs,
             prompt=prompt,
             module=self.source_model,
@@ -433,6 +433,7 @@ def run(self, prompt):
             use_i64_token=self.args.embedding_quantize is not None,
             collect_logits=True,
         )
+        golden_logits = result.logits
 
         input_file_name = f"{self.args.artifact}/input_tokens.raw"
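
This tuple-to-dataclass migration recurs at every call site the commit touches. An illustrative before/after, where run_inference stands in for any INFERENCE_REGISTRY entry:

    # Before: positional unpacking; discarded values need placeholder names.
    golden_logits, _ = run_inference(...)

    # After: named access on DecoderOutputs; unrequested fields stay None.
    result = run_inference(...)
    golden_logits = result.logits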

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 53 additions & 28 deletions
@@ -77,6 +77,13 @@ class DecoderInputs:
     embedding: Optional[torch.Tensor] = None
 
 
+@dataclass
+class DecoderOutputs:
+    logits: Optional[torch.Tensor] = None
+    token_list: Optional[List[int]] = None
+    input_samples: Optional[List] = None
+
+
 class GraphModuleCalibrationWrapper(EagerEvalWrapper):
     """
     A wrapper class for calibration
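
A self-contained sanity sketch of the new result type (the class body mirrors the hunk above): every field defaults to None, so producers populate only what the caller requested and consumers can test for absence explicitly.

    from dataclasses import dataclass
    from typing import List, Optional

    import torch

    @dataclass
    class DecoderOutputs:  # mirrors the dataclass added above
        logits: Optional[torch.Tensor] = None
        token_list: Optional[List[int]] = None
        input_samples: Optional[List] = None

    outputs = DecoderOutputs(token_list=[1, 2, 3])
    assert outputs.logits is None and outputs.input_samples is None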
@@ -94,6 +101,7 @@ def __init__( # noqa: C901
         get_example_inputs: Callable,
         use_i64_token: bool,
         seq_mse_candidates: int,
+        collect_input_samples: bool = False,
     ):
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
         assert max_seq_length is not None, "max_seq_length must be provided"
@@ -108,30 +116,32 @@ def __init__( # noqa: C901
         self.use_i64_token = use_i64_token
         self.seq_mse_candidates = seq_mse_candidates
         self._input_samples = None
+        self.collect_input_samples = collect_input_samples
 
     def get_input_samples(self):
         return self._input_samples
 
     def _model_call(self, inps):
-        all_logits = None
         kwargs = {}
         if self._use_kv_cache:
             kwargs["ar_len"] = self.ar_len
             kwargs["seq_mse_candidates"] = self.seq_mse_candidates
 
-        all_logits, self._input_samples = INFERENCE_REGISTRY[self._use_kv_cache](
+        result = INFERENCE_REGISTRY[self._use_kv_cache](
             self.get_example_inputs,
             inps,
             self._model,
             self._tokenizer,
             max_seq_len=self.max_seq_length,
             use_i64_token=self.use_i64_token,
             collect_logits=True,
+            collect_input_samples=self.collect_input_samples,
             **kwargs,
         )
+        self._input_samples = result.input_samples
         # one shot is enough for seq mse
         self.seq_mse_candidates = 0
-        return all_logits
+        return result.logits
 
 
 class LookaheadDecoder:
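
How the wrapper is driven, as a hypothetical sketch (constructor argument names are inferred from this diff and from the call site in graph_module_inference further down): the lm-eval harness consumes only the logits returned by _model_call, while input samples accumulate on the wrapper as a side effect and are fetched afterwards.

    import torch

    token_batch = torch.tensor([[1, 2, 3]])  # toy input_ids batch
    wrapper = GraphModuleCalibrationWrapper(
        model=module,
        tokenizer=tokenizer,
        max_seq_length=max_context_len,
        ar_len=ar_len,
        use_kv_cache=True,
        get_example_inputs=get_example_inputs,
        use_i64_token=False,
        seq_mse_candidates=0,
        collect_input_samples=True,
    )
    logits = wrapper._model_call(token_batch)  # calibration forward pass
    samples = wrapper.get_input_samples()      # populated during _model_call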
@@ -731,7 +741,8 @@ def kv_inference( # noqa: C901
     collect_logits=False,
     seq_mse_candidates=0,
     lookahead_config=None,
-):
+    collect_input_samples=False,
+) -> DecoderOutputs:
     input_samples = []  # Record input sample for quantization error analysis
     is_multimodal = all(
         [
@@ -818,6 +829,7 @@
 
     # record total input tokens and generated tokens
     total_token_list = prompt_token_list
+    last_token_in_prompt = prompt_token_list[-1] if len(prompt_token_list) > 0 else None
 
     # 3. prepare decoder inputs
     inputs = DecoderInputs(
@@ -845,28 +857,33 @@
 
     # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
     # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
-    generate_input_sample = _generate(
-        inputs,
-        cur_pos,
-        module,
-        tokenizer,
-        tok_embedding,
-        ar_len,
-        max_seq_len,
-        k_caches,
-        v_caches,
-        total_token_list,
-        lookahead_config,
-    )
-    if generate_input_sample is not None:
-        input_samples.append(generate_input_sample)
-    else:
-        input_samples.append(prefill_input_sample)
+    generate_input_sample = None
+    if last_token_in_prompt != tokenizer.eos_id:
+        generate_input_sample = _generate(
+            inputs,
+            cur_pos,
+            module,
+            tokenizer,
+            tok_embedding,
+            ar_len,
+            max_seq_len,
+            k_caches,
+            v_caches,
+            total_token_list,
+            lookahead_config,
+        )
+
+    if collect_input_samples:
+        input_samples.append(generate_input_sample or prefill_input_sample)
 
     logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
-    return result_logits, input_samples
+    return DecoderOutputs(
+        logits=result_logits if collect_logits else None,
+        token_list=total_token_list,
+        input_samples=input_samples if collect_input_samples else None,
+    )
 
 
 @register_inference(use_kv_cache=False)
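
Two behavioral changes land in the hunk above: _generate is skipped entirely when the prompt already ends with the EOS token, and input samples are recorded only when collection is requested. The fallback leans on Python's `or` treating None as falsy; a tiny illustration with placeholder values:

    generate_input_sample = None            # generation was skipped
    prefill_input_sample = {"ids": [1, 2]}  # placeholder sample for illustration
    chosen = generate_input_sample or prefill_input_sample
    assert chosen is prefill_input_sample   # None falls through to the prefill sample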
@@ -882,7 +899,8 @@ def prefill_inference(
     max_seq_len=512,
     use_i64_token=False,
     collect_logits=False,
-):
+    collect_input_samples=False,
+) -> DecoderOutputs:
     input_samples = None  # Record input sample for quantization error analysis
     is_multimodal = all(
         [
@@ -950,7 +968,11 @@
         pos += 1
     if isinstance(prompt, str):
         logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
-    return result_logits, [input_samples]
+    return DecoderOutputs(
+        logits=result_logits if collect_logits else None,
+        token_list=token_list,
+        input_samples=[input_samples] if collect_input_samples else None,
+    )
 
 
 def graph_module_inference(
@@ -972,7 +994,8 @@
     event_name: Optional[str] = None,
     seq_mse_candidates: int = 0,
     lookahead_config: Optional[Tuple[int]] = None,
-):
+    collect_input_samples: bool = False,
+) -> DecoderOutputs:
     """
     This function supports model execution from static nn.Module decoder model
     all the way to edge program.
@@ -988,7 +1011,7 @@
         kwargs["ar_len"] = ar_len
         kwargs["lookahead_config"] = lookahead_config
 
-        _, input_samples = INFERENCE_REGISTRY[use_kv_cache](
+        result = INFERENCE_REGISTRY[use_kv_cache](
             get_example_inputs,
             prompt,
             module,
@@ -1000,10 +1023,11 @@
             max_seq_len=max_seq_len,
             use_i64_token=use_i64_token,
             collect_logits=False,
+            collect_input_samples=collect_input_samples,
             **kwargs,
         )
         logging.info(f"Prompt summary for {event_name}")
-        return input_samples
+        return result
     else:
         calibration_wrapper = GraphModuleCalibrationWrapper(
             model=module,
@@ -1014,6 +1038,7 @@
             get_example_inputs=get_example_inputs,
             use_i64_token=use_i64_token,
             seq_mse_candidates=seq_mse_candidates,
+            collect_input_samples=collect_input_samples,
         )
         with torch.no_grad():
             eval_results = simple_evaluate(
@@ -1026,4 +1051,4 @@
         for task, res in eval_results["results"].items():
             logging.info(f"{task}: {res}")
 
-        return calibration_wrapper.get_input_samples()
+        return DecoderOutputs(input_samples=calibration_wrapper.get_input_samples())
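
A hypothetical driver for the eager path, with argument names inferred from the diff (the full parameter list is not shown here): request input-sample collection and read the result off the dataclass rather than a bare tuple.

    result = graph_module_inference(
        get_example_inputs=get_example_inputs,
        prompt=calibration_prompt,
        module=source_model,
        tokenizer=tokenizer,
        use_kv_cache=True,
        ar_len=ar_len,
        max_seq_len=max_context_len,
        collect_input_samples=True,
    )
    input_samples = result.input_samples  # None unless collection was requested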

examples/qualcomm/oss_scripts/llama/wrappers/base_component.py

Lines changed: 10 additions & 3 deletions
@@ -40,6 +40,7 @@
 class Mode(Enum):
     PREFILL = 1
     DECODE = 2
+    CALIBRATE = 3
 
 
 def log_info(func):
@@ -83,7 +84,7 @@ def process_model_args(
         model_args: ModelArgs object to be modified.
         quant_recipe: Quantization recipe to be used.
         config: LLMModelConfig object to be used.
-        mode: Mode of operation (PREFILL or DECODE).
+        mode: Mode of operation (PREFILL, DECODE, or CALIBRATE).
     """
     # TODO: support batch inputs if necessary
     if mode == Mode.DECODE:
@@ -95,13 +96,19 @@
             if control_args.model_mode == "lookahead"
             else 1
         )
-    else:
+    elif mode == Mode.PREFILL:
         ar_len = control_args.prefill_ar_len
+    elif mode == Mode.CALIBRATE:
+        ar_len = control_args.max_context_len
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
     model_args.max_batch_size = 1
     model_args.max_seq_len = control_args.max_seq_len
     model_args.max_context_len = control_args.max_context_len
-    model_args.use_kv_cache = control_args.max_context_len != ar_len
+    model_args.use_kv_cache = (
+        control_args.max_context_len != ar_len or mode == Mode.CALIBRATE
+    )
     model_args.enable_r3 = config.r3
     model_args.ar_len = ar_len
     model_args.kv_io_bit_width = quant_recipe.get_kv_io_bit_width()
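
Why CALIBRATE gets its own ar_len: with ar_len equal to max_context_len, one graph invocation covers the whole context window, which is what makes prefill-only calibration cheap; the `or mode == Mode.CALIBRATE` clause then keeps use_kv_cache on even though the two lengths match, so the kv_inference path is still taken. Illustrative arithmetic with hypothetical sizes:

    max_context_len = 2048
    prefill_ar_len = 128

    # PREFILL-mode calibration: one autoregressive step per ar_len-sized chunk.
    steps_prefill = max_context_len // prefill_ar_len      # 16 invocations
    # CALIBRATE mode: ar_len == max_context_len, a single pass per sample.
    steps_calibrate = max_context_len // max_context_len   # 1 invocation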
