Skip to content

Commit 1941d07

Browse files
Qualcomm AI Engine Direct - PTQ Mix precision guidance for LLMs (pytorch#18969)
1 parent cd81156 commit 1941d07

9 files changed

Lines changed: 1099 additions & 11 deletions

File tree

backends/qualcomm/quantizer/quant_recipe.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import re
9+
import textwrap
910
from abc import ABC, abstractmethod
1011
from enum import IntEnum, unique
1112
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
@@ -424,3 +425,76 @@ def summary(self, max_rows: int = -1):
424425
rows.append(["..."] * len(headers))
425426

426427
return tabulate(rows, headers=headers, tablefmt="grid")
428+
429+
def to_source(self) -> str:
    """
    Serializes this QuantRecipe into a Python source string at zero indentation.

    The emitted string is a parenthesized builder expression: a ``QuantRecipe(...)``
    constructor call followed by chained ``.add_node_target(...)`` /
    ``.add_regex(...)`` calls, one per recorded strategy. Each strategy's note is
    rendered as ``#``-prefixed comment lines above its call.
    """

    # Tab indentation is used inside the generated source.
    tab = "\t"

    def _fmt_dtype(dtype: QuantDtype) -> str:
        # Render an enum member as source text, e.g. "QuantDtype.use_8a4w".
        return f"QuantDtype.{dtype.name}"

    def _fmt_granularity(granularity: QuantGranularity) -> str:
        return f"QuantGranularity.{granularity.name}"

    def _note_comments(note: str) -> str:
        # One "# ..." line per non-empty note line; empty note yields "".
        stripped = note.strip()
        if not stripped:
            return ""
        return "".join(f"# {line}\n" for line in stripped.splitlines())

    def _arg_lines(*entries: str) -> str:
        # One indented, comma-terminated line per call argument.
        return "".join(f"{tab}{entry},\n" for entry in entries)

    chain_calls: List[str] = []
    for strategy in self._strategies:
        # extra_kwargs is optional; only emit the keyword when it is truthy.
        optional_extra = (
            [f"extra_kwargs={strategy.extra_kwargs!r}"] if strategy.extra_kwargs else []
        )

        if isinstance(strategy, ByNodeTarget):
            method_name = "add_node_target"
            op_names = []
            for target in sorted(strategy.targets, key=str):
                qualified = target._overloadpacket._qualified_op_name.replace("::", ".")
                op_names.append(f"torch.ops.{qualified}.{target._overloadname}")
            first_arg = f"{{{', '.join(op_names)}}}"
        elif isinstance(strategy, ByNameRegex):
            method_name = "add_regex"
            pattern_literals = (f'r"{pattern}"' for pattern in sorted(strategy.patterns))
            first_arg = f"{{{', '.join(pattern_literals)}}}"
        else:
            # Unknown strategy kinds are silently skipped.
            continue

        # Both call kinds share the same trailing argument layout.
        rendered_args = _arg_lines(
            first_arg,
            _fmt_dtype(strategy.quant_dtype),
            str(strategy.is_qat),
            "act_observer=MinMaxObserver",
            f"granularity={_fmt_granularity(strategy.granularity)}",
            *optional_extra,
            f"act_symmetric={strategy.act_symmetric}",
            f"note={strategy.note!r}",
        )
        chain_calls.append(
            _note_comments(strategy.note) + f".{method_name}(\n{rendered_args})"
        )

    # Constructor call carrying the recipe-wide defaults. "verbose=verbose" and
    # "self.default_quant_dtype" are emitted verbatim: the generated source is
    # expected to be pasted into a context where those names are bound.
    constructor_args = _arg_lines(
        "self.default_quant_dtype",
        str(self._default_is_qat),
        "act_observer=MinMaxObserver",
        f"granularity={_fmt_granularity(self._default_granularity)}",
        "verbose=verbose",
    )
    constructor = f"QuantRecipe(\n{constructor_args})"
    full_expression = constructor + "\n" + "\n".join(chain_calls)
    return "(\n" + textwrap.indent(full_expression, tab) + "\n)"

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9057,6 +9057,66 @@ def test_intermediate_debugger(self):
90579057
f"CSV valid count: {csv_valid_count}. SVG valid count: {svg_valid_count}"
90589058
)
90599059

9060+
def test_analyzer_to_file_generation(self):
    """
    End-to-end test for PerLayerSqnrAnalyzer → SqnrReport → file generation.
    """
    from executorch.examples.qualcomm.oss_scripts.llama.mix_precision_analyzer import (
        PerLayerSqnrAnalyzer,
        save_suggest_recipes,
    )

    model = SimpleModel()  # noqa: F405
    example_inputs = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
    fp32_gm = torch.export.export(model, example_inputs, strict=True).module()
    qdq_gm = self.get_qdq_module(
        model, example_inputs, quant_dtype=QuantDtype.use_8a4w
    )

    analyzer = PerLayerSqnrAnalyzer(
        model_name="simple_conv",
        num_layers=4,
        fp32_gm=fp32_gm,
        qdq_gm=qdq_gm,
    )
    report = analyzer.analyze([example_inputs], num_sharding=4)
    overrides = report.suggest_recipe_overrides(sqnr_threshold=22.0)

    with tempfile.TemporaryDirectory() as tmp_dir:
        report.save_analysis_summary(output_dir=tmp_dir)
        save_suggest_recipes(report, overrides, output_dir=tmp_dir)

        # --- save_analysis_summary csv file ---
        with open(f"{tmp_dir}/simple_conv_quantization_error.csv") as csv_file:
            csv_content = csv_file.read()
        parsed_rows = list(csv.reader(csv_content.splitlines()))
        self.assertEqual(len(parsed_rows), 5)  # 1 header + 4 group rows
        expected_header = [
            "group_name",
            "avg_sqnr",
            "median_sqnr",
            "min_sqnr",
            "max_sqnr",
            "count",
        ]
        self.assertEqual(parsed_rows[0], expected_header)
        print(f"Sensitivity analysis:\n{csv_content}")

        # --- save_suggest_recipes .py file (only written when sensitive layers exist) ---
        if overrides:
            with open(f"{tmp_dir}/simple_conv_suggest_recipe.py") as py_file:
                py_content = py_file.read()
            # generated file must be valid Python
            try:
                compile(py_content, "simple_conv_suggest_recipe.py", "exec")
            except SyntaxError as e:
                self.fail(
                    f"Generated recipe file has syntax error: {e}\n{py_content}"
                )
            self.assertIn("HOW TO USE THESE RECIPES", py_content)
9119+
90609120

90619121
def setup_environment():
90629122
parser = setup_common_args_and_variables()

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,20 @@ Example:
467467
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods sqnr_eval
468468
```
469469

470+
#### Quantization Guidance
471+
472+
To automatically identify sensitive layers and generate a mixed-precision recipe suggestion, add the `--quant_recipe_suggestion` flag. During calibration, the analyzer compares FP32 and QDQ intermediate outputs layer-by-layer using SQNR, then writes two files to the working directory:
473+
474+
- `{model_name}_quantization_error.csv` — per-group SQNR statistics sorted by sensitivity (most sensitive first)
475+
- `{model_name}_suggest_recipe.py` — ready-to-use `StaticLLMQuantRecipe` subclasses that apply higher-precision quantization to the most sensitive groups.
476+
477+
Example:
478+
```bash
479+
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --tasks wikitext --limit 1 --quant_recipe_suggestion --compile_only
480+
```
481+
482+
After the run, pick one of the generated classes from `qwen3-1_7b_suggest_recipe.py` as your new recipe. For a full walkthrough, see [quantization_guidance.md](quantization_guidance.md).
483+
470484
#### Use attention sink for multi-turn conversations
471485
Attention sink is a way to evict the cache when the maximum context length is reached.
472486
There are two main concepts behind attention sink:

examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def __init__(
403403
self.max_seq_length = pte_max_context_len
404404

405405
def run(self, prompt):
406-
golden_logits = INFERENCE_REGISTRY[True](
406+
golden_logits, _ = INFERENCE_REGISTRY[True](
407407
get_example_inputs=self.get_example_inputs,
408408
prompt=prompt,
409409
module=self.source_model,

0 commit comments

Comments
 (0)