Skip to content

Commit 42a5446

Browse files
authored
Qualcomm AI Engine Direct - remove prefill calibration (#17805)
### Summary

- Calibrate only the KV text decoder to reduce calibration time.
- Deprecate the outdated implementation and use deterministic example inputs for LLMs.

**Total Quantization Time**

| Model | Before(s) | After(s) | Improvement |
| :---: | :---: | :---: | :---: |
| gemma-2b | 2203.399 | 999.512 | 54.64% |
| gemma2-2b | 2177.285 | 1001.248 | 54.01% |
| gemma3-1b | 1776.861 | 548.312 | 69.14% |
| glm-1_5b | 1434.780 | 677.257 | 52.8% |
| granite_3_3-2b | 59566.790 | 6165.443 | 89.65% |
| llama3_2-1b | 4528.620 | 2953.233 | 34.79% |
| llama3_2-3b | 5744.429 | 1652.157 | 71.24% |
| phi_4_mini | 7005.601 | 2071.634 | 84.56% |
| qwen2_5-0_5b | 480.508 | 372.076 | 22.57% |
| qwen2_5-1_5b | 2064.333 | 899.164 | 56.44% |
| qwen3-0_6b | 1673.150 | 1124.149 | 32.81% |
| qwen3-1_7b | 3253.723 | 1148.511 | 64.7% |
| smollm2_135m | 502.779 | 414.510 | 17.56% |
| smollm3-3b | 4663.057 | 1613.516 | 65.4% |
| smolvlm_500m_instruct | 288.246 | 170.829 | 40.73% |
| internvl3_1b | 256.624 | 170.811 | 33.44% |

### Test plan

`python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript / TestExampleMultimodalityScript`
1 parent 6c02866 commit 42a5446

5 files changed

Lines changed: 130 additions & 338 deletions

File tree

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
)
2222
from executorch.exir.dialects._ops import ops as exir_ops
2323
from torch.fx import Node
24-
from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
24+
from torchao.quantization.pt2e import MinMaxObserver
2525
from torchao.quantization.pt2e.quantizer import (
2626
annotate_input_qspec_map,
2727
annotate_output_qspec,
2828
QuantizationAnnotation,
29-
QuantizationSpec,
3029
SharedQuantizationSpec,
3130
)
3231

@@ -92,40 +91,6 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule):
9291
break
9392

9493

95-
def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
96-
for node in gm.graph.nodes:
97-
if node.op == "output":
98-
for index, prefill_output in enumerate(node.args[0]):
99-
kv_quant_attr = kv_quant_attrs[index]
100-
fixed_observer = FixedQParamsObserver.with_args(
101-
scale=kv_quant_attr[0],
102-
zero_point=kv_quant_attr[1],
103-
quant_min=kv_quant_attr[2],
104-
quant_max=kv_quant_attr[3],
105-
dtype=kv_quant_attr[4],
106-
qscheme=torch.torch.per_tensor_affine,
107-
)
108-
109-
fixed_output_spec = QuantizationSpec(
110-
quant_min=kv_quant_attr[2],
111-
quant_max=kv_quant_attr[3],
112-
dtype=kv_quant_attr[4],
113-
ch_axis=0,
114-
observer_or_fake_quant_ctr=fixed_observer,
115-
)
116-
117-
input_qspec_map = {}
118-
for input in prefill_output.args:
119-
if isinstance(input, Node):
120-
input_qspec_map[input] = fixed_output_spec
121-
122-
prefill_output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
123-
input_qspec_map=input_qspec_map,
124-
output_qspec=fixed_output_spec,
125-
_annotated=True,
126-
)
127-
128-
12994
def annotate_kv_8bit( # noqa: C901
13095
gm: torch.fx.GraphModule,
13196
is_qat=False,

examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp

Lines changed: 0 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -294,108 +294,6 @@ Error MultimodalRunner<T>::load() {
294294
cache_mode_,
295295
static_cast<int32_t>(dim)});
296296

297-
if (eval_mode_ == EvalMode::kLookaheadDecoding ||
298-
eval_mode_ == EvalMode::kHybrid) {
299-
output_k_cache_scales_.resize(num_layers);
300-
output_k_cache_zero_points_.resize(num_layers);
301-
output_v_cache_scales_.resize(num_layers);
302-
output_v_cache_zero_points_.resize(num_layers);
303-
for (int i = 0; i < num_layers; i++) {
304-
std::string get_k_scale_output_name =
305-
"get_k_scale_output_" + std::to_string(i);
306-
std::string get_k_zero_point_output_name =
307-
"get_k_zero_point_output_" + std::to_string(i);
308-
std::string get_v_scale_output_name =
309-
"get_v_scale_output_" + std::to_string(i);
310-
std::string get_v_zero_point_output_name =
311-
"get_v_zero_point_output_" + std::to_string(i);
312-
313-
if (module_->method_names()->count(get_k_scale_output_name) > 0) {
314-
output_k_cache_scales_[i] = static_cast<float>(
315-
ET_UNWRAP(module_->get(get_k_scale_output_name)).toDouble());
316-
} else {
317-
ET_LOG(Error, "Cannot find method %s", get_k_scale_output_name.c_str());
318-
return Error::Internal;
319-
}
320-
if (module_->method_names()->count(get_k_zero_point_output_name) > 0) {
321-
output_k_cache_zero_points_[i] = static_cast<T>(
322-
ET_UNWRAP(module_->get(get_k_zero_point_output_name)).toInt());
323-
} else {
324-
ET_LOG(
325-
Error,
326-
"Cannot find method %s",
327-
get_k_zero_point_output_name.c_str());
328-
return Error::Internal;
329-
}
330-
if (module_->method_names()->count(get_v_scale_output_name) > 0) {
331-
output_v_cache_scales_[i] = static_cast<float>(
332-
ET_UNWRAP(module_->get(get_v_scale_output_name)).toDouble());
333-
} else {
334-
ET_LOG(Error, "Cannot find method %s", get_v_scale_output_name.c_str());
335-
return Error::Internal;
336-
}
337-
if (module_->method_names()->count(get_v_zero_point_output_name) > 0) {
338-
output_v_cache_zero_points_[i] = static_cast<T>(
339-
ET_UNWRAP(module_->get(get_v_zero_point_output_name)).toInt());
340-
} else {
341-
ET_LOG(
342-
Error,
343-
"Cannot find method %s",
344-
get_v_zero_point_output_name.c_str());
345-
return Error::Internal;
346-
}
347-
}
348-
// Load scale and zero point for quantized input KV cache
349-
input_k_cache_scales_.resize(num_layers);
350-
input_k_cache_zero_points_.resize(num_layers);
351-
input_v_cache_scales_.resize(num_layers);
352-
input_v_cache_zero_points_.resize(num_layers);
353-
for (int i = 0; i < num_layers; i++) {
354-
std::string get_k_scale_input_name =
355-
"get_k_scale_input_" + std::to_string(i);
356-
std::string get_k_zero_point_input_name =
357-
"get_k_zero_point_input_" + std::to_string(i);
358-
std::string get_v_scale_input_name =
359-
"get_v_scale_input_" + std::to_string(i);
360-
std::string get_v_zero_point_input_name =
361-
"get_v_zero_point_input_" + std::to_string(i);
362-
if (module_->method_names()->count(get_k_scale_input_name) > 0) {
363-
input_k_cache_scales_[i] = static_cast<float>(
364-
ET_UNWRAP(module_->get(get_k_scale_input_name)).toDouble());
365-
} else {
366-
ET_LOG(Error, "Cannot find method %s", get_k_scale_input_name.c_str());
367-
return Error::Internal;
368-
}
369-
if (module_->method_names()->count(get_k_zero_point_input_name) > 0) {
370-
input_k_cache_zero_points_[i] = static_cast<T>(
371-
ET_UNWRAP(module_->get(get_k_zero_point_input_name)).toInt());
372-
} else {
373-
ET_LOG(
374-
Error,
375-
"Cannot find method %s",
376-
get_k_zero_point_input_name.c_str());
377-
return Error::Internal;
378-
}
379-
if (module_->method_names()->count(get_v_scale_input_name) > 0) {
380-
input_v_cache_scales_[i] = static_cast<float>(
381-
ET_UNWRAP(module_->get(get_v_scale_input_name)).toDouble());
382-
} else {
383-
ET_LOG(Error, "Cannot find method %s", get_v_scale_input_name.c_str());
384-
return Error::Internal;
385-
}
386-
if (module_->method_names()->count(get_v_zero_point_input_name) > 0) {
387-
input_v_cache_zero_points_[i] = static_cast<T>(
388-
ET_UNWRAP(module_->get(get_v_zero_point_input_name)).toInt());
389-
} else {
390-
ET_LOG(
391-
Error,
392-
"Cannot find method %s",
393-
get_v_zero_point_input_name.c_str());
394-
return Error::Internal;
395-
}
396-
}
397-
}
398-
399297
// Initialize EmbeddingGenerator
400298
embedding_generator_ = std::make_unique<EmbeddingProcessor>(
401299
embedding_runner_.get(),
@@ -599,46 +497,6 @@ Error MultimodalRunner<T>::generate_from_prompt_or_file(
599497
// start the main loop
600498
prompt_tokens.push_back(cur_token);
601499

602-
// Requant kv cache for prefill decode I/O
603-
if (eval_mode_ == EvalMode::kLookaheadDecoding ||
604-
eval_mode_ == EvalMode::kHybrid) {
605-
int64_t num_heads = prompt_processor_->get_num_heads();
606-
int64_t num_layers = prompt_processor_->get_num_layers();
607-
int64_t head_dim = kv_manager_->get_head_dim();
608-
std::vector<KVCache<T>> k_cache_ptrs = kv_manager_->get_k_cache_();
609-
std::vector<KVCache<T>> v_cache_ptrs = kv_manager_->get_v_cache_();
610-
611-
const int64_t num_elems_per_layer =
612-
(context_len_ - 1) * num_heads * head_dim;
613-
// Requant kv cache from prefill output scale/zero_point to decode input
614-
// scale/zero_point
615-
for (int layer_idx = 0; layer_idx < num_layers; layer_idx++) {
616-
T* k_cache_data = k_cache_ptrs[layer_idx].buffer;
617-
T* v_cache_data = v_cache_ptrs[layer_idx].buffer;
618-
619-
const float scale_ratio_k =
620-
output_k_cache_scales_[layer_idx] / input_k_cache_scales_[layer_idx];
621-
const float scale_ratio_v =
622-
output_v_cache_scales_[layer_idx] / input_v_cache_scales_[layer_idx];
623-
624-
for (int64_t i = 0; i < num_elems_per_layer; i++) {
625-
// Requant k_cache_data from prefill output scale/zero_point to decode
626-
// input scale/zero_point
627-
k_cache_data[i] = static_cast<T>(
628-
(k_cache_data[i] - output_k_cache_zero_points_[layer_idx]) *
629-
scale_ratio_k +
630-
input_k_cache_zero_points_[layer_idx]);
631-
632-
// Requant v_cache_data from prefill output scale/zero_point to decode
633-
// input scale/zero_point
634-
v_cache_data[i] = static_cast<T>(
635-
(v_cache_data[i] - output_v_cache_zero_points_[layer_idx]) *
636-
scale_ratio_v +
637-
input_v_cache_zero_points_[layer_idx]);
638-
}
639-
}
640-
}
641-
642500
int64_t num_generated_tokens = ET_UNWRAP(token_generator_->generate(
643501
prompt_tokens, cur_pos_, seq_len, token_callback, dump_logits, nullptr));
644502
stats_.inference_end_ms = time_in_ms();

examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,6 @@ class MultimodalRunner : public executorch::extension::llm::IRunner {
141141
multimodal_embeddings_dim_order_;
142142
TensorStruct<float> merged_embeddings_;
143143

144-
// scale and zero point for quantized KV cache
145-
std::vector<float> input_k_cache_scales_;
146-
std::vector<T> input_k_cache_zero_points_;
147-
std::vector<float> input_v_cache_scales_;
148-
std::vector<T> input_v_cache_zero_points_;
149-
std::vector<float> output_k_cache_scales_;
150-
std::vector<T> output_k_cache_zero_points_;
151-
std::vector<float> output_v_cache_scales_;
152-
std::vector<T> output_v_cache_zero_points_;
153-
154144
// stats
155145
executorch::llm::Stats stats_;
156146
};

examples/qualcomm/oss_scripts/llama/wrappers/base_component.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,37 +42,6 @@ class Mode(Enum):
4242
DECODE = 2
4343

4444

45-
def is_node_src_start_with_name(node: torch.fx.Node, prefix: str) -> bool:
46-
"""
47-
Return True if any NodeSource in node.meta['from_node']
48-
has a `name` starting with `prefix`.
49-
"""
50-
51-
def has_source_name_prefix(
52-
node_src: torch.fx.traceback.NodeSource, prefix: str
53-
) -> bool:
54-
55-
name = getattr(node_src, "name", None)
56-
if isinstance(name, str) and name.startswith(prefix):
57-
return True
58-
59-
children = getattr(node_src, "from_node", None)
60-
if not children:
61-
return False
62-
63-
for src in children:
64-
if has_source_name_prefix(src, prefix):
65-
return True
66-
67-
return False
68-
69-
node_srcs = node.meta.get("from_node", None)
70-
if not node_srcs:
71-
return False
72-
73-
return any(has_source_name_prefix(node_src, prefix) for node_src in node_srcs)
74-
75-
7645
def log_info(func):
7746
class TimeIt:
7847
def __init__(self, event):

0 commit comments

Comments
 (0)