Skip to content

Commit 0795efb

Browse files
Qualcomm AI Engine Direct - [Multimodal] granite-speech-3.3-2b (#18740)
Summary: - Support granite-speech-3.3-2b - Extend Audio modality in QNNMultimodal AOT flow - Extend Audio modality in QNNMultimodal runner - Support encoder model sharding Pull Request resolved: #18740 Test Plan: #### CI ``` bash python -m backends.qualcomm.tests.test_qnn_delegate TestExampleMultimodalityScript.test_static_asr --model_name granite_speech_3_3-2b build-android --executorch_root . -a . -m SM8750 -s ${SERIAL_NUM} ``` #### Script ```bash python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m SM8750 --decoder_model granite_speech_3_3-2b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "can you transcribe the speech into a written format?" --audio_path "https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true" ``` Audio file: https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true Prompt: "can you transcribe the speech into a written format?" Result ``` bash I 00:00:16.333997 executorch:multimodal_runner.cpp:542] RSS after finishing text generation: 614.941406 MiB (0 if unsupported) I 00:00:16.334231 executorch:stats.h:161] Prompt Tokens: 212 Generated Tokens: 201 I 00:00:16.334356 executorch:stats.h:167] Model Load Time: 1.460000 (seconds) I 00:00:16.334419 executorch:stats.h:177] Total inference time: 14.871000 (seconds) Rate: 13.516240 (tokens/second) I 00:00:16.334480 executorch:stats.h:185] Prompt evaluation: 0.798000 (seconds) Rate: 265.664160 (tokens/second) I 00:00:16.334541 executorch:stats.h:196] Generated 201 tokens: 14.073000 (seconds) Rate: 14.282669 (tokens/second) I 00:00:16.334629 executorch:stats.h:204] Time to first generated token: 0.798000 (seconds) I 00:00:16.334688 executorch:stats.h:211] Sampling time over 413 tokens: 0.479000 (seconds) [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn context [INFO] [Qnn ExecuTorch]: Destroy Qnn device PyTorchObserver {"prefill_token_per_sec":265.664,"decode_token_per_sec":14.2827,"prompt_tokens":212,"generated_tokens":201,"model_load_start_ms":1744743525724,"model_load_end_ms":1744743527184,"inference_start_ms":1744743527186,"inference_end_ms":1744743542057,"prompt_eval_end_ms":1744743527984,"first_token_ms":1744743527984,"aggregate_sampling_time_ms":479,"SCALING_FACTOR_UNITS_PER_SECOND":1000} [INFO] [Qnn ExecuTorch]: Destroy Qnn backend /data/local/tmp/yuyazhua/executorch/static_llm/outputs/outputs.txt: 1 file pulled. 0.9 MB/s (1170 bytes in 0.001s) /data/local/tmp/yuyazhua/executorch/static_llm/outputs/inference_speed.txt: 1 file pulled. 0.0 MB/s (7 bytes in 0.002s) [INFO 2026-04-08 00:22:11,849 llama.py:243] Device Inference Results[0]: <|start_of_role|>system<|end_of_role|>You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|> <|start_of_role|>user<|end_of_role|>can you transcribe the speech into a written format?<|end_of_text|> <|start_of_role|>assistant<|end_of_role|>It appears you've provided a fragment of a sentence, possibly from a poem or text, and you're asking for a transcription or translation into written format. However, without the complete context or original text, it's challenging to accurately transcribe or translate it. If we were to proceed with a hypothetical example, here's a possible continuation of the sentence in a written format: "After his nap, Timothy leisurely stretched his foot, first one then the other, carefully selecting the choicest bits. Turning over the food, he methodically picked out the desired portions, meticulously choosing what was to be included in his meal." This continuation assumes a narrative style, where Timothy is taking care of food preparation. The original sentence seems to be a playful or poetic exploration of a character's actions, possibly related to food preparation or a cooking process.<|end_of_text|> ``` cc: abhinaykukkadapu, cccclai, haowhsu-quic Differential Revision: D101574849 Pulled By: abhinaykukkadapu
1 parent 069a793 commit 0795efb

37 files changed

Lines changed: 1643 additions & 535 deletions

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 87 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -116,86 +116,98 @@ def call(self, graph_module: torch.fx.GraphModule):
116116
)
117117

118118
with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
119-
filter_arg = node.args[1]
120-
filter_node = (
121-
filter_arg
122-
if filter_arg.op == "placeholder"
123-
else node.args[1].args[0]
124-
)
125-
filter_node.meta["val"] = (
126-
filter_node.meta["val"].unsqueeze(2).contiguous()
127-
)
128-
filter_tensor = get_parameter(
129-
filter_node, self.edge_program
130-
).unsqueeze(2)
131-
set_parameter(
132-
(
133-
torch.nn.Parameter(filter_tensor)
134-
if filter_tensor.dtype == torch.float
135-
else filter_tensor
136-
),
137-
filter_node,
138-
self.edge_program,
139-
)
140-
141-
num_args = len(node.args)
142-
143-
bias_node = node.args[2] if num_args > 2 else None
144-
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
145-
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
146-
if node.target == torch.ops.aten.conv1d.default:
147-
dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
148-
groups = node.args[6] if num_args > 6 else 1
149-
conv_args = (
150-
qdq_node_after_unsqueeze,
151-
node.args[1],
152-
bias_node,
153-
stride,
154-
padding,
155-
dilation,
156-
groups,
119+
# conv2d must be inserted before conv1d in the graph to preserve correct
120+
# topological ordering. This is required due to conv-bn fusion: when conv1d
121+
# has no bias, the fused bias (from batchnorm) is introduced as a new node,
122+
# and its corresponding dq (dequantize) node must appear before conv2d in
123+
# the execution order.
124+
with graph_module.graph.inserting_before(node):
125+
filter_arg = node.args[1]
126+
filter_node = (
127+
filter_arg
128+
if filter_arg.op == "placeholder"
129+
else node.args[1].args[0]
157130
)
158-
else:
159-
output_padding = (
160-
[0] + node.args[5] if num_args > 5 else [0, 0]
131+
filter_node.meta["val"] = filter_node.meta["val"].unsqueeze(
132+
2
161133
)
162-
groups = node.args[6] if num_args > 6 else 1
163-
dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
164-
conv_args = (
165-
qdq_node_after_unsqueeze,
166-
node.args[1],
167-
bias_node,
168-
stride,
169-
padding,
170-
output_padding,
171-
groups,
172-
dilation,
173-
)
174-
conv2d_node = graph.create_node(
175-
"call_function",
176-
self.conv1d_op_map[node.target],
177-
conv_args,
178-
)
179-
conv2d_node.meta = copy_meta(
180-
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
181-
)
182-
qdq_node_after_conv2d = append_qdq(
183-
graph_module=graph_module,
184-
node=conv2d_node,
185-
qdq_node=list(node.users)[0],
186-
)
187-
188-
with graph_module.graph.inserting_after(qdq_node_after_conv2d):
189-
squeeze_op = torch.ops.aten.squeeze_copy.dims
190-
squeeze_node = graph.create_node(
191-
"call_function",
192-
squeeze_op,
134+
filter_tensor = get_parameter(
135+
filter_node, self.edge_program
136+
).unsqueeze(2)
137+
set_parameter(
193138
(
194-
qdq_node_after_conv2d,
195-
[2],
139+
torch.nn.Parameter(filter_tensor)
140+
if filter_tensor.dtype == torch.float
141+
else filter_tensor
196142
),
143+
filter_node,
144+
self.edge_program,
145+
)
146+
147+
num_args = len(node.args)
148+
149+
bias_node = node.args[2] if num_args > 2 else None
150+
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
151+
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
152+
if node.target == torch.ops.aten.conv1d.default:
153+
dilation = (
154+
[1] + node.args[5] if num_args > 5 else [1, 1]
155+
)
156+
groups = node.args[6] if num_args > 6 else 1
157+
conv_args = (
158+
qdq_node_after_unsqueeze,
159+
node.args[1],
160+
bias_node,
161+
stride,
162+
padding,
163+
dilation,
164+
groups,
165+
)
166+
else:
167+
output_padding = (
168+
[0] + node.args[5] if num_args > 5 else [0, 0]
169+
)
170+
groups = node.args[6] if num_args > 6 else 1
171+
dilation = (
172+
[1] + node.args[7] if num_args > 7 else [1, 1]
173+
)
174+
conv_args = (
175+
qdq_node_after_unsqueeze,
176+
node.args[1],
177+
bias_node,
178+
stride,
179+
padding,
180+
output_padding,
181+
groups,
182+
dilation,
183+
)
184+
conv2d_node = graph.create_node(
185+
"call_function",
186+
self.conv1d_op_map[node.target],
187+
conv_args,
188+
)
189+
conv2d_node.meta = copy_meta(
190+
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
197191
)
198-
squeeze_node.meta = copy_meta(node.meta)
192+
qdq_node_after_conv2d = append_qdq(
193+
graph_module=graph_module,
194+
node=conv2d_node,
195+
qdq_node=list(node.users)[0],
196+
)
197+
198+
with graph_module.graph.inserting_after(
199+
qdq_node_after_conv2d
200+
):
201+
squeeze_op = torch.ops.aten.squeeze_copy.dims
202+
squeeze_node = graph.create_node(
203+
"call_function",
204+
squeeze_op,
205+
(
206+
qdq_node_after_conv2d,
207+
[2],
208+
),
209+
)
210+
squeeze_node.meta = copy_meta(node.meta)
199211

200212
for user in node.users.copy():
201213
user.replace_input_with(node, squeeze_node)

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,24 @@ def example_inputs(self):
524524
}
525525

526526

527+
class Conv1dBn(torch.nn.Module):
528+
def __init__(self, bias=True):
529+
super().__init__()
530+
self.conv = torch.nn.Conv1d(
531+
in_channels=2048,
532+
out_channels=2048,
533+
kernel_size=15,
534+
groups=2048,
535+
bias=bias,
536+
)
537+
self.batch_norm = torch.nn.BatchNorm1d(2048)
538+
539+
def forward(self, x):
540+
x = self.conv(x)
541+
x = self.batch_norm(x)
542+
return x
543+
544+
527545
class Conv1dSequential(torch.nn.Module):
528546
def __init__(self, bias=True):
529547
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,13 @@ def test_qnn_backend_conv1d(self):
414414
with self.subTest(i=i):
415415
self.lower_module_and_test_output(module, sample_input)
416416

417+
def test_qnn_conv1d_batch_norm(self):
418+
modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405
419+
sample_input = (torch.randn([1, 2048, 858]),)
420+
for i, module in enumerate(modules):
421+
with self.subTest(i=i):
422+
self.lower_module_and_test_output(module, sample_input)
423+
417424
def test_qnn_backend_conv2d(self):
418425
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
419426
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -2809,6 +2816,14 @@ def test_qnn_backend_conv1d(self):
28092816
module = self.get_qdq_module(module, sample_input)
28102817
self.lower_module_and_test_output(module, sample_input)
28112818

2819+
def test_qnn_conv1d_batch_norm(self):
2820+
modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405
2821+
sample_input = (torch.randn([1, 2048, 858]),)
2822+
for i, module in enumerate(modules):
2823+
with self.subTest(i=i):
2824+
module = self.get_qdq_module(module, sample_input)
2825+
self.lower_module_and_test_output(module, sample_input)
2826+
28122827
def test_qnn_backend_conv2d(self):
28132828
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
28142829
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -7239,13 +7254,30 @@ class MLLMSpecs:
72397254
tok_embedding_pte_size: float
72407255
decoder_pte_size: float
72417256

7257+
@dataclass(frozen=True)
7258+
class ALMSpecs(MLLMSpecs):
7259+
audio_path: str
7260+
golden_audio_feature: str
7261+
72427262
@dataclass(frozen=True)
72437263
class VLMSpecs(MLLMSpecs):
72447264
image_path: str
72457265
golden_image_feature: str
72467266

72477267
# TODO: refactor to support different backends
72487268
def setUp(self):
7269+
self.alm_specs = {
7270+
"granite_speech_3_3-2b": TestExampleMultimodalityScript.ALMSpecs(
7271+
max_seq_len=512,
7272+
sm8650_token_rate=5,
7273+
sm8750_token_rate=8,
7274+
encoder_pte_size=900_000_000, # 900MB
7275+
tok_embedding_pte_size=240_000_000, # 240MB
7276+
decoder_pte_size=3_000_000_000, # 3GB
7277+
audio_path="https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true", # Audio content: after his nap,...
7278+
golden_audio_feature="after his nap,",
7279+
),
7280+
}
72497281
self.vlm_specs = {
72507282
"smolvlm_500m_instruct": TestExampleMultimodalityScript.VLMSpecs(
72517283
max_seq_len=128,
@@ -7269,6 +7301,96 @@ def setUp(self):
72697301
),
72707302
}
72717303

7304+
def test_static_asr(self):
7305+
if not self.required_envs([self.model_name]):
7306+
self.skipTest("missing required envs")
7307+
7308+
if self.enable_x86_64:
7309+
# Running on host is extremely slow for large models, so we skip this check to avoid timeouts.
7310+
# Please verify the output on the actual device instead.
7311+
self.skipTest(
7312+
"Skipping the check for the static ASR model on x86 due to long execution time."
7313+
)
7314+
7315+
alm_specs: TestExampleMultimodalityScript.ALMSpecs = self.alm_specs[
7316+
self.model_name
7317+
]
7318+
prompt = "can you transcribe the speech into a written format?"
7319+
audio_path = alm_specs.audio_path
7320+
cmds = [
7321+
"python",
7322+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
7323+
"--artifact",
7324+
self.artifact_dir,
7325+
"--build_folder",
7326+
self.build_folder,
7327+
"--soc_model",
7328+
self.soc_model,
7329+
"--ip",
7330+
self.ip,
7331+
"--port",
7332+
str(self.port),
7333+
"--prompt",
7334+
prompt,
7335+
"--audio_path",
7336+
audio_path,
7337+
"--temperature",
7338+
"0",
7339+
"--decoder_model",
7340+
f"{self.model_name}",
7341+
"--model_mode",
7342+
"kv",
7343+
"--max_seq_len",
7344+
f"{alm_specs.max_seq_len}",
7345+
]
7346+
if self.compile_only:
7347+
cmds.extend(["--compile_only"])
7348+
elif self.device:
7349+
cmds.extend(["--device", self.device])
7350+
if self.host:
7351+
cmds.extend(["--host", self.host])
7352+
if self.pre_gen_pte:
7353+
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
7354+
7355+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
7356+
with Listener((self.ip, self.port)) as listener:
7357+
conn = listener.accept()
7358+
p.communicate()
7359+
msg = json.loads(conn.recv())
7360+
if "Error" in msg:
7361+
self.fail(msg["Error"])
7362+
else:
7363+
if not self.compile_only:
7364+
model_out = msg["result"][0]
7365+
self.assertTrue(
7366+
alm_specs.golden_audio_feature in model_out.lower(),
7367+
f"Expected Output contains feature: '{alm_specs.golden_audio_feature}' Actual Output: '{model_out}'",
7368+
)
7369+
print(f"Audio Path: {audio_path}")
7370+
print(f"Query: {prompt}")
7371+
print(f"Answer: {model_out}")
7372+
7373+
encoder_pte_size = msg["audio_encoder_pte_size"]
7374+
tok_embedding_pte_size = msg["tok_embedding_pte_size"]
7375+
decoder_pte_size = msg["pte_size"]
7376+
self.assertLessEqual(encoder_pte_size, alm_specs.encoder_pte_size)
7377+
self.assertLessEqual(
7378+
tok_embedding_pte_size, alm_specs.tok_embedding_pte_size
7379+
)
7380+
self.assertLessEqual(decoder_pte_size, alm_specs.decoder_pte_size)
7381+
print(f"Encoder PTE Size: {encoder_pte_size} bytes")
7382+
print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes")
7383+
print(f"Text Decoder PTE Size: {decoder_pte_size} bytes")
7384+
7385+
attr_name = f"{self.soc_model.lower()}_token_rate"
7386+
if not self.compile_only and hasattr(alm_specs, attr_name):
7387+
device_inference_speed = msg["inference_speed"]
7388+
expected_inference_speed = getattr(alm_specs, attr_name)
7389+
print(f"Prompt Evaluation: {device_inference_speed} tokens/second")
7390+
self.assertGreaterEqual(
7391+
device_inference_speed, expected_inference_speed
7392+
)
7393+
72727394
def test_static_vlm(self):
72737395
if not self.required_envs([self.model_name]):
72747396
self.skipTest("missing required envs")
@@ -7333,7 +7455,7 @@ def test_static_vlm(self):
73337455
print(f"Query: {prompt}")
73347456
print(f"Answer: {model_out}")
73357457
if not self.enable_x86_64:
7336-
encoder_pte_size = msg["encoder_pte_size"]
7458+
encoder_pte_size = msg["vision_encoder_pte_size"]
73377459
tok_embedding_pte_size = msg["tok_embedding_pte_size"]
73387460
decoder_pte_size = msg["pte_size"]
73397461
self.assertLessEqual(encoder_pte_size, vlm_specs.encoder_pte_size)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
oncall("executorch")
5+
6+
fbcode_target(_kind = runtime.python_library,
7+
name = "granite_speech",
8+
srcs = [
9+
"__init__.py",
10+
"convert_weights.py",
11+
],
12+
_is_external_target = True,
13+
base_module = "executorch.examples.models.granite_speech",
14+
resources = {
15+
"config/2b_config.json": "config/2b_config.json",
16+
},
17+
deps = [
18+
"//caffe2:torch",
19+
"//executorch/examples/models/llama:llama2_model",
20+
"fbcode//pytorch/torchtune:lib",
21+
"fbsource//third-party/pypi/safetensors:safetensors",
22+
],
23+
visibility = ["PUBLIC"],
24+
)

0 commit comments

Comments
 (0)