Skip to content

Commit b01dff2

Browse files
Qualcomm AI Engine Direct - [Multimodal] granite-3.3-2b-instruct
Summary: support granite-speech-3.3-2b; extend the Audio modality in the QNN Multimodal AOT flow; extend the Audio modality in the QNN Multimodal runner; support encoder model sharding.
1 parent fcccda3 commit b01dff2

37 files changed

Lines changed: 1626 additions & 534 deletions

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 87 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -115,86 +115,98 @@ def call(self, graph_module: torch.fx.GraphModule):
115115
)
116116

117117
with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
118-
filter_arg = node.args[1]
119-
filter_node = (
120-
filter_arg
121-
if filter_arg.op == "placeholder"
122-
else node.args[1].args[0]
123-
)
124-
filter_node.meta["val"] = (
125-
filter_node.meta["val"].unsqueeze(2).contiguous()
126-
)
127-
filter_tensor = get_parameter(
128-
filter_node, self.edge_program
129-
).unsqueeze(2)
130-
set_parameter(
131-
(
132-
torch.nn.Parameter(filter_tensor)
133-
if filter_tensor.dtype == torch.float
134-
else filter_tensor
135-
),
136-
filter_node,
137-
self.edge_program,
138-
)
139-
140-
num_args = len(node.args)
141-
142-
bias_node = node.args[2] if num_args > 2 else None
143-
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
144-
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
145-
if node.target == torch.ops.aten.conv1d.default:
146-
dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
147-
groups = node.args[6] if num_args > 6 else 1
148-
conv_args = (
149-
qdq_node_after_unsqueeze,
150-
node.args[1],
151-
bias_node,
152-
stride,
153-
padding,
154-
dilation,
155-
groups,
118+
# conv2d must be inserted before conv1d in the graph to preserve correct
119+
# topological ordering. This is required due to conv-bn fusion: when conv1d
120+
# has no bias, the fused bias (from batchnorm) is introduced as a new node,
121+
# and its corresponding dq (dequantize) node must appear before conv2d in
122+
# the execution order.
123+
with graph_module.graph.inserting_before(node):
124+
filter_arg = node.args[1]
125+
filter_node = (
126+
filter_arg
127+
if filter_arg.op == "placeholder"
128+
else node.args[1].args[0]
156129
)
157-
else:
158-
output_padding = (
159-
[0] + node.args[5] if num_args > 5 else [0, 0]
130+
filter_node.meta["val"] = filter_node.meta["val"].unsqueeze(
131+
2
160132
)
161-
groups = node.args[6] if num_args > 6 else 1
162-
dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
163-
conv_args = (
164-
qdq_node_after_unsqueeze,
165-
node.args[1],
166-
bias_node,
167-
stride,
168-
padding,
169-
output_padding,
170-
groups,
171-
dilation,
172-
)
173-
conv2d_node = graph.create_node(
174-
"call_function",
175-
self.conv1d_op_map[node.target],
176-
conv_args,
177-
)
178-
conv2d_node.meta = copy_meta(
179-
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
180-
)
181-
qdq_node_after_conv2d = append_qdq(
182-
graph_module=graph_module,
183-
node=conv2d_node,
184-
qdq_node=list(node.users)[0],
185-
)
186-
187-
with graph_module.graph.inserting_after(qdq_node_after_conv2d):
188-
squeeze_op = torch.ops.aten.squeeze_copy.dims
189-
squeeze_node = graph.create_node(
190-
"call_function",
191-
squeeze_op,
133+
filter_tensor = get_parameter(
134+
filter_node, self.edge_program
135+
).unsqueeze(2)
136+
set_parameter(
192137
(
193-
qdq_node_after_conv2d,
194-
[2],
138+
torch.nn.Parameter(filter_tensor)
139+
if filter_tensor.dtype == torch.float
140+
else filter_tensor
195141
),
142+
filter_node,
143+
self.edge_program,
144+
)
145+
146+
num_args = len(node.args)
147+
148+
bias_node = node.args[2] if num_args > 2 else None
149+
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
150+
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
151+
if node.target == torch.ops.aten.conv1d.default:
152+
dilation = (
153+
[1] + node.args[5] if num_args > 5 else [1, 1]
154+
)
155+
groups = node.args[6] if num_args > 6 else 1
156+
conv_args = (
157+
qdq_node_after_unsqueeze,
158+
node.args[1],
159+
bias_node,
160+
stride,
161+
padding,
162+
dilation,
163+
groups,
164+
)
165+
else:
166+
output_padding = (
167+
[0] + node.args[5] if num_args > 5 else [0, 0]
168+
)
169+
groups = node.args[6] if num_args > 6 else 1
170+
dilation = (
171+
[1] + node.args[7] if num_args > 7 else [1, 1]
172+
)
173+
conv_args = (
174+
qdq_node_after_unsqueeze,
175+
node.args[1],
176+
bias_node,
177+
stride,
178+
padding,
179+
output_padding,
180+
groups,
181+
dilation,
182+
)
183+
conv2d_node = graph.create_node(
184+
"call_function",
185+
self.conv1d_op_map[node.target],
186+
conv_args,
187+
)
188+
conv2d_node.meta = copy_meta(
189+
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
196190
)
197-
squeeze_node.meta = copy_meta(node.meta)
191+
qdq_node_after_conv2d = append_qdq(
192+
graph_module=graph_module,
193+
node=conv2d_node,
194+
qdq_node=list(node.users)[0],
195+
)
196+
197+
with graph_module.graph.inserting_after(
198+
qdq_node_after_conv2d
199+
):
200+
squeeze_op = torch.ops.aten.squeeze_copy.dims
201+
squeeze_node = graph.create_node(
202+
"call_function",
203+
squeeze_op,
204+
(
205+
qdq_node_after_conv2d,
206+
[2],
207+
),
208+
)
209+
squeeze_node.meta = copy_meta(node.meta)
198210

199211
for user in node.users.copy():
200212
user.replace_input_with(node, squeeze_node)

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,24 @@ def example_inputs(self):
490490
}
491491

492492

493+
class Conv1dBn(torch.nn.Module):
494+
def __init__(self, bias=True):
495+
super().__init__()
496+
self.conv = torch.nn.Conv1d(
497+
in_channels=2048,
498+
out_channels=2048,
499+
kernel_size=15,
500+
groups=2048,
501+
bias=bias,
502+
)
503+
self.batch_norm = torch.nn.BatchNorm1d(2048)
504+
505+
def forward(self, x):
506+
x = self.conv(x)
507+
x = self.batch_norm(x)
508+
return x
509+
510+
493511
class Conv1dSequential(torch.nn.Module):
494512
def __init__(self, bias=True):
495513
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,13 @@ def test_qnn_backend_conv1d(self):
377377
with self.subTest(i=i):
378378
self.lower_module_and_test_output(module, sample_input)
379379

380+
def test_qnn_conv1d_batch_norm(self):
381+
modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405
382+
sample_input = (torch.randn([1, 2048, 858]),)
383+
for i, module in enumerate(modules):
384+
with self.subTest(i=i):
385+
self.lower_module_and_test_output(module, sample_input)
386+
380387
def test_qnn_backend_conv2d(self):
381388
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
382389
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -2637,6 +2644,14 @@ def test_qnn_backend_conv1d(self):
26372644
module = self.get_qdq_module(module, sample_input)
26382645
self.lower_module_and_test_output(module, sample_input)
26392646

2647+
def test_qnn_conv1d_batch_norm(self):
2648+
modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405
2649+
sample_input = (torch.randn([1, 2048, 858]),)
2650+
for i, module in enumerate(modules):
2651+
with self.subTest(i=i):
2652+
module = self.get_qdq_module(module, sample_input)
2653+
self.lower_module_and_test_output(module, sample_input)
2654+
26402655
def test_qnn_backend_conv2d(self):
26412656
modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405
26422657
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -6870,13 +6885,30 @@ class MLLMSpecs:
68706885
tok_embedding_pte_size: float
68716886
decoder_pte_size: float
68726887

6888+
@dataclass(frozen=True)
6889+
class ALMSpecs(MLLMSpecs):
6890+
audio_path: str
6891+
golden_audio_feature: str
6892+
68736893
@dataclass(frozen=True)
68746894
class VLMSpecs(MLLMSpecs):
68756895
image_path: str
68766896
golden_image_feature: str
68776897

68786898
# TODO: refactor to support different backends
68796899
def setUp(self):
6900+
self.alm_specs = {
6901+
"granite_speech_3_3-2b": TestExampleMultimodalityScript.ALMSpecs(
6902+
max_seq_len=512,
6903+
sm8650_token_rate=5,
6904+
sm8750_token_rate=8,
6905+
encoder_pte_size=900_000_000, # 900MB
6906+
tok_embedding_pte_size=240_000_000, # 240MB
6907+
decoder_pte_size=3_000_000_000, # 3GB
6908+
audio_path="https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true", # Audio content: after his nap,...
6909+
golden_audio_feature="after his nap,",
6910+
),
6911+
}
68806912
self.vlm_specs = {
68816913
"smolvlm_500m_instruct": TestExampleMultimodalityScript.VLMSpecs(
68826914
max_seq_len=128,
@@ -6900,6 +6932,96 @@ def setUp(self):
69006932
),
69016933
}
69026934

6935+
def test_static_asr(self):
6936+
if not self.required_envs([self.model_name]):
6937+
self.skipTest("missing required envs")
6938+
6939+
if self.enable_x86_64:
6940+
# Running on host is extremely slow for large models, so we skip this check to avoid timeouts.
6941+
# Please verify the output on the actual device instead.
6942+
self.skipTest(
6943+
"Skipping the check for the static ASR model on x86 due to long execution time."
6944+
)
6945+
6946+
alm_specs: TestExampleMultimodalityScript.ALMSpecs = self.alm_specs[
6947+
self.model_name
6948+
]
6949+
prompt = "can you transcribe the speech into a written format?"
6950+
audio_path = alm_specs.audio_path
6951+
cmds = [
6952+
"python",
6953+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
6954+
"--artifact",
6955+
self.artifact_dir,
6956+
"--build_folder",
6957+
self.build_folder,
6958+
"--model",
6959+
self.model,
6960+
"--ip",
6961+
self.ip,
6962+
"--port",
6963+
str(self.port),
6964+
"--prompt",
6965+
prompt,
6966+
"--audio_path",
6967+
audio_path,
6968+
"--temperature",
6969+
"0",
6970+
"--decoder_model",
6971+
f"{self.model_name}",
6972+
"--model_mode",
6973+
"kv",
6974+
"--max_seq_len",
6975+
f"{alm_specs.max_seq_len}",
6976+
]
6977+
if self.compile_only:
6978+
cmds.extend(["--compile_only"])
6979+
elif self.device:
6980+
cmds.extend(["--device", self.device])
6981+
if self.host:
6982+
cmds.extend(["--host", self.host])
6983+
if self.pre_gen_pte:
6984+
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
6985+
6986+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
6987+
with Listener((self.ip, self.port)) as listener:
6988+
conn = listener.accept()
6989+
p.communicate()
6990+
msg = json.loads(conn.recv())
6991+
if "Error" in msg:
6992+
self.fail(msg["Error"])
6993+
else:
6994+
if not self.compile_only:
6995+
model_out = msg["result"][0]
6996+
self.assertTrue(
6997+
alm_specs.golden_audio_feature in model_out.lower(),
6998+
f"Expected Output contains feature: '{alm_specs.golden_audio_feature}' Actual Output: '{model_out}'",
6999+
)
7000+
print(f"Audio Path: {audio_path}")
7001+
print(f"Query: {prompt}")
7002+
print(f"Answer: {model_out}")
7003+
7004+
encoder_pte_size = msg["audio_encoder_pte_size"]
7005+
tok_embedding_pte_size = msg["tok_embedding_pte_size"]
7006+
decoder_pte_size = msg["pte_size"]
7007+
self.assertLessEqual(encoder_pte_size, alm_specs.encoder_pte_size)
7008+
self.assertLessEqual(
7009+
tok_embedding_pte_size, alm_specs.tok_embedding_pte_size
7010+
)
7011+
self.assertLessEqual(decoder_pte_size, alm_specs.decoder_pte_size)
7012+
print(f"Encoder PTE Size: {encoder_pte_size} bytes")
7013+
print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes")
7014+
print(f"Text Decoder PTE Size: {decoder_pte_size} bytes")
7015+
7016+
attr_name = f"{self.model.lower()}_token_rate"
7017+
if not self.compile_only and hasattr(alm_specs, attr_name):
7018+
device_inference_speed = msg["inference_speed"]
7019+
expected_inference_speed = getattr(alm_specs, attr_name)
7020+
print(f"Prompt Evaluation: {device_inference_speed} tokens/second")
7021+
self.assertGreaterEqual(
7022+
device_inference_speed, expected_inference_speed
7023+
)
7024+
69037025
def test_static_vlm(self):
69047026
if not self.required_envs([self.model_name]):
69057027
self.skipTest("missing required envs")
@@ -6964,7 +7086,7 @@ def test_static_vlm(self):
69647086
print(f"Query: {prompt}")
69657087
print(f"Answer: {model_out}")
69667088
if not self.enable_x86_64:
6967-
encoder_pte_size = msg["encoder_pte_size"]
7089+
encoder_pte_size = msg["vision_encoder_pte_size"]
69687090
tok_embedding_pte_size = msg["tok_embedding_pte_size"]
69697091
decoder_pte_size = msg["pte_size"]
69707092
self.assertLessEqual(encoder_pte_size, vlm_specs.encoder_pte_size)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
oncall("executorch")
5+
6+
fbcode_target(_kind = runtime.python_library,
7+
name = "granite_speech",
8+
srcs = [
9+
"__init__.py",
10+
"convert_weights.py",
11+
],
12+
_is_external_target = True,
13+
base_module = "executorch.examples.models.granite_speech",
14+
resources = {
15+
"config/2b_config.json": "config/2b_config.json",
16+
},
17+
deps = [
18+
"//caffe2:torch",
19+
"//executorch/examples/models/llama:llama2_model",
20+
"fbcode//pytorch/torchtune:lib",
21+
"fbsource//third-party/pypi/safetensors:safetensors",
22+
],
23+
visibility = ["PUBLIC"],
24+
)

0 commit comments

Comments (0)