Skip to content

Commit 2e4c9bc

Browse files
Qualcomm AI Engine Direct - [Multimodal] granite-3.3-2b-instruct
Summary:
- Support granite-speech-3.3-2b
- Extend Audio modality in QNNMultimodal AOT flow
- Extend Audio modality in QNNMultimodal runner
- Support encoder model sharding
1 parent e281726 commit 2e4c9bc

37 files changed

Lines changed: 1626 additions & 535 deletions

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 87 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -116,86 +116,98 @@ def call(self, graph_module: torch.fx.GraphModule):
116116
)
117117

118118
with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
119-
filter_arg = node.args[1]
120-
filter_node = (
121-
filter_arg
122-
if filter_arg.op == "placeholder"
123-
else node.args[1].args[0]
124-
)
125-
filter_node.meta["val"] = (
126-
filter_node.meta["val"].unsqueeze(2).contiguous()
127-
)
128-
filter_tensor = get_parameter(
129-
filter_node, self.edge_program
130-
).unsqueeze(2)
131-
set_parameter(
132-
(
133-
torch.nn.Parameter(filter_tensor)
134-
if filter_tensor.dtype == torch.float
135-
else filter_tensor
136-
),
137-
filter_node,
138-
self.edge_program,
139-
)
140-
141-
num_args = len(node.args)
142-
143-
bias_node = node.args[2] if num_args > 2 else None
144-
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
145-
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
146-
if node.target == torch.ops.aten.conv1d.default:
147-
dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
148-
groups = node.args[6] if num_args > 6 else 1
149-
conv_args = (
150-
qdq_node_after_unsqueeze,
151-
node.args[1],
152-
bias_node,
153-
stride,
154-
padding,
155-
dilation,
156-
groups,
119+
# conv2d must be inserted before conv1d in the graph to preserve correct
120+
# topological ordering. This is required due to conv-bn fusion: when conv1d
121+
# has no bias, the fused bias (from batchnorm) is introduced as a new node,
122+
# and its corresponding dq (dequantize) node must appear before conv2d in
123+
# the execution order.
124+
with graph_module.graph.inserting_before(node):
125+
filter_arg = node.args[1]
126+
filter_node = (
127+
filter_arg
128+
if filter_arg.op == "placeholder"
129+
else node.args[1].args[0]
157130
)
158-
else:
159-
output_padding = (
160-
[0] + node.args[5] if num_args > 5 else [0, 0]
131+
filter_node.meta["val"] = filter_node.meta["val"].unsqueeze(
132+
2
161133
)
162-
groups = node.args[6] if num_args > 6 else 1
163-
dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
164-
conv_args = (
165-
qdq_node_after_unsqueeze,
166-
node.args[1],
167-
bias_node,
168-
stride,
169-
padding,
170-
output_padding,
171-
groups,
172-
dilation,
173-
)
174-
conv2d_node = graph.create_node(
175-
"call_function",
176-
self.conv1d_op_map[node.target],
177-
conv_args,
178-
)
179-
conv2d_node.meta = copy_meta(
180-
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
181-
)
182-
qdq_node_after_conv2d = append_qdq(
183-
graph_module=graph_module,
184-
node=conv2d_node,
185-
qdq_node=list(node.users)[0],
186-
)
187-
188-
with graph_module.graph.inserting_after(qdq_node_after_conv2d):
189-
squeeze_op = torch.ops.aten.squeeze_copy.dims
190-
squeeze_node = graph.create_node(
191-
"call_function",
192-
squeeze_op,
134+
filter_tensor = get_parameter(
135+
filter_node, self.edge_program
136+
).unsqueeze(2)
137+
set_parameter(
193138
(
194-
qdq_node_after_conv2d,
195-
[2],
139+
torch.nn.Parameter(filter_tensor)
140+
if filter_tensor.dtype == torch.float
141+
else filter_tensor
196142
),
143+
filter_node,
144+
self.edge_program,
145+
)
146+
147+
num_args = len(node.args)
148+
149+
bias_node = node.args[2] if num_args > 2 else None
150+
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
151+
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
152+
if node.target == torch.ops.aten.conv1d.default:
153+
dilation = (
154+
[1] + node.args[5] if num_args > 5 else [1, 1]
155+
)
156+
groups = node.args[6] if num_args > 6 else 1
157+
conv_args = (
158+
qdq_node_after_unsqueeze,
159+
node.args[1],
160+
bias_node,
161+
stride,
162+
padding,
163+
dilation,
164+
groups,
165+
)
166+
else:
167+
output_padding = (
168+
[0] + node.args[5] if num_args > 5 else [0, 0]
169+
)
170+
groups = node.args[6] if num_args > 6 else 1
171+
dilation = (
172+
[1] + node.args[7] if num_args > 7 else [1, 1]
173+
)
174+
conv_args = (
175+
qdq_node_after_unsqueeze,
176+
node.args[1],
177+
bias_node,
178+
stride,
179+
padding,
180+
output_padding,
181+
groups,
182+
dilation,
183+
)
184+
conv2d_node = graph.create_node(
185+
"call_function",
186+
self.conv1d_op_map[node.target],
187+
conv_args,
188+
)
189+
conv2d_node.meta = copy_meta(
190+
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
197191
)
198-
squeeze_node.meta = copy_meta(node.meta)
192+
qdq_node_after_conv2d = append_qdq(
193+
graph_module=graph_module,
194+
node=conv2d_node,
195+
qdq_node=list(node.users)[0],
196+
)
197+
198+
with graph_module.graph.inserting_after(
199+
qdq_node_after_conv2d
200+
):
201+
squeeze_op = torch.ops.aten.squeeze_copy.dims
202+
squeeze_node = graph.create_node(
203+
"call_function",
204+
squeeze_op,
205+
(
206+
qdq_node_after_conv2d,
207+
[2],
208+
),
209+
)
210+
squeeze_node.meta = copy_meta(node.meta)
199211

200212
for user in node.users.copy():
201213
user.replace_input_with(node, squeeze_node)

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,24 @@ def example_inputs(self):
524524
}
525525

526526

527+
class Conv1dBn(torch.nn.Module):
    """Depthwise Conv1d followed by BatchNorm1d.

    Exercises the conv-bn fusion path: when ``bias=False`` the fused bias
    (from the batch norm) must be introduced as a new node by the backend.
    """

    def __init__(self, bias=True):
        super().__init__()
        channels = 2048
        # groups == channels makes this a depthwise convolution.
        self.conv = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=15,
            groups=channels,
            bias=bias,
        )
        self.batch_norm = torch.nn.BatchNorm1d(channels)

    def forward(self, x):
        return self.batch_norm(self.conv(x))
543+
544+
527545
class Conv1dSequential(torch.nn.Module):
528546
def __init__(self, bias=True):
529547
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,13 @@ def test_qnn_backend_conv1d(self):
414414
with self.subTest(i=i):
415415
self.lower_module_and_test_output(module, sample_input)
416416

417+
def test_qnn_conv1d_batch_norm(self):
    """Lower Conv1d+BatchNorm and compare against eager output."""
    # Cover both the biased and bias-free conv: conv-bn fusion must
    # synthesize a bias node when the conv itself has none.
    modules = [Conv1dBn(bias=b) for b in (True, False)]  # noqa: F405
    sample_input = (torch.randn([1, 2048, 858]),)
    for i, module in enumerate(modules):
        with self.subTest(i=i):
            self.lower_module_and_test_output(module, sample_input)
423+
417424
def test_qnn_backend_conv2d(self):
418425
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
419426
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -2809,6 +2816,14 @@ def test_qnn_backend_conv1d(self):
28092816
module = self.get_qdq_module(module, sample_input)
28102817
self.lower_module_and_test_output(module, sample_input)
28112818

2819+
def test_qnn_conv1d_batch_norm(self):
    """Quantized variant: lower q/dq-annotated Conv1d+BatchNorm."""
    modules = [Conv1dBn(bias=b) for b in (True, False)]  # noqa: F405
    sample_input = (torch.randn([1, 2048, 858]),)
    for i, module in enumerate(modules):
        with self.subTest(i=i):
            # Quantize first so the lowering passes see the q/dq graph.
            module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(module, sample_input)
2826+
28122827
def test_qnn_backend_conv2d(self):
28132828
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
28142829
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -7239,13 +7254,30 @@ class MLLMSpecs:
72397254
tok_embedding_pte_size: float
72407255
decoder_pte_size: float
72417256

7257+
@dataclass(frozen=True)
class ALMSpecs(MLLMSpecs):
    """Audio-language-model test spec: MLLMSpecs fields plus audio fixtures."""

    # URL (or path) of the audio clip passed to the example script.
    audio_path: str
    # Substring expected to appear in the transcription (checked against
    # the lowercased model output).
    golden_audio_feature: str
7261+
72427262
@dataclass(frozen=True)
72437263
class VLMSpecs(MLLMSpecs):
72447264
image_path: str
72457265
golden_image_feature: str
72467266

72477267
# TODO: refactor to support different backends
72487268
def setUp(self):
7269+
self.alm_specs = {
7270+
"granite_speech_3_3-2b": TestExampleMultimodalityScript.ALMSpecs(
7271+
max_seq_len=512,
7272+
sm8650_token_rate=5,
7273+
sm8750_token_rate=8,
7274+
encoder_pte_size=900_000_000, # 900MB
7275+
tok_embedding_pte_size=240_000_000, # 240MB
7276+
decoder_pte_size=3_000_000_000, # 3GB
7277+
audio_path="https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true", # Audio content: after his nap,...
7278+
golden_audio_feature="after his nap,",
7279+
),
7280+
}
72497281
self.vlm_specs = {
72507282
"smolvlm_500m_instruct": TestExampleMultimodalityScript.VLMSpecs(
72517283
max_seq_len=128,
@@ -7269,6 +7301,96 @@ def setUp(self):
72697301
),
72707302
}
72717303

7304+
def test_static_asr(self):
    """End-to-end test of the static ASR (speech-to-text) example.

    Compiles/runs the llama.py multimodal flow for the configured audio
    model, then checks transcription content, PTE artifact sizes, and
    (when available) device inference speed against the model's ALMSpecs.
    """
    if not self.required_envs([self.model_name]):
        self.skipTest("missing required envs")

    if self.enable_x86_64:
        # Running on host is extremely slow for large models, so we skip this check to avoid timeouts.
        # Please verify the output on the actual device instead.
        self.skipTest(
            "Skipping the check for the static ASR model on x86 due to long execution time."
        )

    # Specs keyed by model name; KeyError here means the model has no ALM entry.
    alm_specs: TestExampleMultimodalityScript.ALMSpecs = self.alm_specs[
        self.model_name
    ]
    prompt = "can you transcribe the speech into a written format?"
    audio_path = alm_specs.audio_path
    # Command line for the example script; results come back over the
    # multiprocessing Listener connection below, not stdout.
    cmds = [
        "python",
        f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
        "--artifact",
        self.artifact_dir,
        "--build_folder",
        self.build_folder,
        "--soc_model",
        self.soc_model,
        "--ip",
        self.ip,
        "--port",
        str(self.port),
        "--prompt",
        prompt,
        "--audio_path",
        audio_path,
        "--temperature",
        "0",
        "--decoder_model",
        f"{self.model_name}",
        "--model_mode",
        "kv",
        "--max_seq_len",
        f"{alm_specs.max_seq_len}",
    ]
    if self.compile_only:
        cmds.extend(["--compile_only"])
    elif self.device:
        cmds.extend(["--device", self.device])
        # NOTE(review): `--host` nested under the device branch, matching the
        # VLM test's pattern — confirm against the original file's indentation.
        if self.host:
            cmds.extend(["--host", self.host])
    if self.pre_gen_pte:
        cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

    # The child process connects back and sends a JSON result payload.
    p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        p.communicate()
        msg = json.loads(conn.recv())
        if "Error" in msg:
            self.fail(msg["Error"])
        else:
            if not self.compile_only:
                model_out = msg["result"][0]
                # Substring match on the lowercased transcription output.
                self.assertTrue(
                    alm_specs.golden_audio_feature in model_out.lower(),
                    f"Expected Output contains feature: '{alm_specs.golden_audio_feature}' Actual Output: '{model_out}'",
                )
                print(f"Audio Path: {audio_path}")
                print(f"Query: {prompt}")
                print(f"Answer: {model_out}")

            # Guard against artifact-size regressions on all three PTEs.
            encoder_pte_size = msg["audio_encoder_pte_size"]
            tok_embedding_pte_size = msg["tok_embedding_pte_size"]
            decoder_pte_size = msg["pte_size"]
            self.assertLessEqual(encoder_pte_size, alm_specs.encoder_pte_size)
            self.assertLessEqual(
                tok_embedding_pte_size, alm_specs.tok_embedding_pte_size
            )
            self.assertLessEqual(decoder_pte_size, alm_specs.decoder_pte_size)
            print(f"Encoder PTE Size: {encoder_pte_size} bytes")
            print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes")
            print(f"Text Decoder PTE Size: {decoder_pte_size} bytes")

            # Speed threshold only exists for SoCs with a matching
            # <soc>_token_rate field on the spec dataclass.
            attr_name = f"{self.soc_model.lower()}_token_rate"
            if not self.compile_only and hasattr(alm_specs, attr_name):
                device_inference_speed = msg["inference_speed"]
                expected_inference_speed = getattr(alm_specs, attr_name)
                print(f"Prompt Evaluation: {device_inference_speed} tokens/second")
                self.assertGreaterEqual(
                    device_inference_speed, expected_inference_speed
                )
7393+
72727394
def test_static_vlm(self):
72737395
if not self.required_envs([self.model_name]):
72747396
self.skipTest("missing required envs")
@@ -7333,7 +7455,7 @@ def test_static_vlm(self):
73337455
print(f"Query: {prompt}")
73347456
print(f"Answer: {model_out}")
73357457
if not self.enable_x86_64:
7336-
encoder_pte_size = msg["encoder_pte_size"]
7458+
encoder_pte_size = msg["vision_encoder_pte_size"]
73377459
tok_embedding_pte_size = msg["tok_embedding_pte_size"]
73387460
decoder_pte_size = msg["pte_size"]
73397461
self.assertLessEqual(encoder_pte_size, vlm_specs.encoder_pte_size)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# Python library for the granite-speech example model (AOT export helpers).
fbcode_target(_kind = runtime.python_library,
    name = "granite_speech",
    srcs = [
        "__init__.py",
        "convert_weights.py",
    ],
    _is_external_target = True,
    base_module = "executorch.examples.models.granite_speech",
    # Ship the model config alongside the sources so it resolves at runtime.
    resources = {
        "config/2b_config.json": "config/2b_config.json",
    },
    deps = [
        "//caffe2:torch",
        "//executorch/examples/models/llama:llama2_model",
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/safetensors:safetensors",
    ],
    visibility = ["PUBLIC"],
)

0 commit comments

Comments
 (0)