
Commit 4403c28

wwwindxingguo01 authored and committed
Arm backend: Add static cache integration test with llama
- Fix SDPA decomposition kwargs handling and Llama test comments.
- Canonicalize SDPA inputs from args, kwargs, and defaults before tracing so the decomposition preserves attn_mask, dropout_p, is_causal, and scale correctly.
- Also preserve nn_module_stack on copied nodes, replace the GQA assert with an explicit ValueError, add a regression test for the kwargs path, and remove stale comments in the ARM Llama TOSA tests.

Change-Id: I881fa107f43c9682c18480d01996a5795ae7f086
Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 98a1d66 commit 4403c28
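For readers unpicking the kwargs fix described above: the same aten SDPA call node can carry its optional arguments positionally or as keywords, and the decomposition has to see one normalized form. Below is a minimal standalone sketch of that normalization idea in plain Python (not the actual ExecuTorch pass; the helper name canonicalize and the constant SDPA_OPTIONAL_ARGS are illustrative).

# Minimal sketch: merge positional args, kwargs, and defaults into one
# fixed-order tuple, mirroring SDPA's (q, k, v, attn_mask, dropout_p,
# is_causal) layout plus the kwarg-only scale.
SDPA_OPTIONAL_ARGS = (("attn_mask", None), ("dropout_p", 0.0), ("is_causal", False))

def canonicalize(args, kwargs):
    kwargs = dict(kwargs)
    canonical = list(args[:3])  # q, k, v are always positional
    for index, (name, default) in enumerate(SDPA_OPTIONAL_ARGS, start=3):
        if len(args) > index:
            canonical.append(args[index])                 # supplied positionally
        else:
            canonical.append(kwargs.pop(name, default))   # supplied as kwarg or defaulted
    canonical.append(kwargs.pop("scale", None))
    return tuple(canonical), kwargs                       # leftover kwargs can be rejected

# Both call styles normalize to the same tuple:
print(canonicalize(("q", "k", "v"), {"is_causal": True}))
print(canonicalize(("q", "k", "v", None, 0.0, True), {}))

Both calls print the same canonical tuple, which is what lets the pass hand a fixed-order argument list to the traced decomposition regardless of how the original model invoked SDPA.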

4 files changed

Lines changed: 302 additions & 59 deletions

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 7 additions & 1 deletion
@@ -194,7 +194,13 @@ def _add_bias(
         output_dtype = node.meta["val"].dtype
         bias_data = torch.zeros(size=(output_channels,), dtype=output_dtype)

-        with graph_module.graph.inserting_after(weight_node):
+        # Constant placeholders must appear before user-input placeholders in
+        # the graph. Insert the synthetic bias at the first placeholder slot
+        # instead of near the conv node.
+        first_placeholder = next(
+            n for n in graph_module.graph.nodes if n.op == "placeholder"
+        )
+        with graph_module.graph.inserting_before(first_placeholder):
             bias_node = create_constant_placeholder(
                 self.exported_program,
                 graph=graph_module.graph,
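The hunk above moves the synthetic bias to the front of the placeholder list. A toy torch.fx sketch (unrelated to the real pass; the module M and the placeholder name synthetic_bias are made up for illustration) shows that inserting before the first placeholder is what controls that ordering:

import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

gm = fx.symbolic_trace(M())
first_placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")

# torch.fx keeps placeholders in graph order, so a node created while
# inserting_before(first_placeholder) lands ahead of the user inputs.
with gm.graph.inserting_before(first_placeholder):
    gm.graph.placeholder("synthetic_bias")

print([n.name for n in gm.graph.nodes if n.op == "placeholder"])
# ['synthetic_bias', 'x', 'y']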

backends/arm/test/models/test_llama.py

Lines changed: 90 additions & 4 deletions
@@ -31,7 +31,11 @@

 from executorch.extension.llm.export.config.llm_config import LlmConfig

+from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM
+from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
 input_t = Tuple[torch.Tensor]
+input_th = Tuple[torch.Tensor, torch.Tensor]

 # Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
 this_files_dir = os.path.dirname(os.path.abspath(__file__))
@@ -41,6 +45,22 @@
 logger = logging.getLogger(__name__)


+class HFPositionalAdapter(torch.nn.Module):
+    def __init__(self, exportable):
+        super().__init__()
+        self.inner = exportable
+
+    def forward(self, input_ids, cache_position):
+        # HF StaticCache eager path requires int64 index tensors, but keeping
+        # cache_position as int32 during export capture avoids adding an extra
+        # int64->int32 cast node in the lowered graph.
+        if torch._dynamo.is_compiling():
+            cp = cache_position
+        else:
+            cp = cache_position.to(torch.long)
+        return self.inner(input_ids=input_ids, cache_position=cp)
+
+
 class TestLlama:
     """Test class of Llama models.

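The adapter above keys its cast on torch._dynamo.is_compiling(), which is False in eager execution and True while torch.compile / export is tracing. A minimal standalone illustration of that flag, independent of the test (assumes a torch.compile-capable build):

import torch

def branch_on_capture(x):
    # False in eager execution, True while Dynamo is tracing, so the
    # captured graph keeps only the "compiling" branch.
    if torch._dynamo.is_compiling():
        return x + 1.0
    return x

x = torch.zeros(1)
print(branch_on_capture(x))                 # eager: tensor([0.])
print(torch.compile(branch_on_capture)(x))  # compiled: tensor([1.])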
@@ -51,6 +71,44 @@ class TestLlama:

     """

+    def prepare_model_hf_static(self):
+        """
+        Build a tiny HF LLaMA wrapped with TorchExportableModuleForDecoderOnlyLM (StaticCache)
+        See https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L214C17-L214C53
+        """
+        # Tiny config
+        cfg = LlamaConfig(
+            vocab_size=32000,
+            hidden_size=256,
+            intermediate_size=512,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            use_cache=True,
+        )
+        base = LlamaForCausalLM(cfg).eval()
+
+        # REQUIRED: generation_config must request a 'static' cache with batch_size & max_cache_len
+        base.generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="static",
+            cache_config={"batch_size": 1, "max_cache_len": 128},
+        )
+
+        exportable = TorchExportableModuleForDecoderOnlyLM(
+            model=base, batch_size=1, max_cache_len=128
+        )
+
+        # Positional adapter so the pipeline can call module(*inputs)
+        model_for_pipeline = HFPositionalAdapter(exportable).eval()
+
+        # The tester will call model(*inputs). Provide (input_ids, cache_position)
+        input_ids = torch.tensor([[0]], dtype=torch.long)  # shape [1, 1]
+        cache_position = torch.tensor([0], dtype=torch.int32)  # shape [1]
+        inputs = (input_ids, cache_position)
+
+        return model_for_pipeline, inputs, None
+
     def prepare_model(self):
         checkpoint = None
         params_file = None
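As a quick sanity check of what the new helper returns, one could call it eagerly before handing it to the pipeline. This is a hedged sketch that reuses TestLlama from this file and assumes the tiny model runs eagerly on CPU; it is not part of the committed test.

import torch

# Illustrative only: exercise the helper the way the tester will (model(*inputs)).
model, inputs, _ = TestLlama().prepare_model_hf_static()
input_ids, cache_position = inputs        # shapes [1, 1] and [1]
with torch.no_grad():
    out = model(*inputs)                  # adapter casts cache_position to int64 in eager
print(type(out), input_ids.shape, cache_position.shape)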
@@ -86,13 +144,18 @@ def prepare_model(self):
         # TODO: Enable key value cache
         args = [
             "--disable_dynamic_shape",
+            "--max_seq_length",
+            "4096",
+            "--max_context_length",
+            "4096",
             "-c",
             checkpoint,
             "-p",
             params_file,
             "--model",
             model_name,
         ]
+
         parser = build_args_parser()
         args = parser.parse_args(args)
         llm_config = LlmConfig.from_args(args)
@@ -123,11 +186,10 @@ def test_llama_tosa_FP():
         aten_op=[],
         exir_op=[],
         custom_path="llama_tosa_fb",
-        run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
+        run_on_tosa_ref_model=True,
         use_to_edge_transform_and_lower=True,
         transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
     )
-    pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
     pipeline.run()


@@ -144,12 +206,36 @@ def test_llama_tosa_INT():
         aten_op=[],
         exir_op=[],
         custom_path="llama_tosa_fb_int",
-        run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
+        run_on_tosa_ref_model=True,
         use_to_edge_transform_and_lower=True,
         frobenius_threshold=None,
         cosine_threshold=None,
     )
-    pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
+    pipeline.run()
+
+
+def test_llama_tosa_INT_static():
+    llama_model, llama_inputs, _ = TestLlama().prepare_model_hf_static()
+    if llama_model is None or llama_inputs is None:
+        pytest.skip("Missing model and/or input files")
+
+    with torch.no_grad():
+        pipeline = TosaPipelineINT[input_th](
+            llama_model,
+            llama_inputs,
+            aten_op=[],
+            exir_op=[],
+            custom_path="llama_tosa_hf_static_int",
+            run_on_tosa_ref_model=True,
+            use_to_edge_transform_and_lower=True,
+            fold_quantize=True,
+        )
+    # NOTE: HF StaticCache INT currently keeps two delegated subgraphs
+    # after partitioning on this path, so expect two delegate calls in EXIR.
+    pipeline.change_args(
+        "check_count.exir",
+        {"torch.ops.higher_order.executorch_call_delegate": 2},
+    )
     pipeline.run()


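The check_count.exir expectation in the new static-cache test boils down to counting delegate calls in the lowered program. A hedged sketch of that count for a manual run (the helper name count_delegate_calls is illustrative, and the exported_program argument is assumed to come from a to_edge_transform_and_lower result):

def count_delegate_calls(exported_program):
    # Count executorch_call_delegate call_function nodes; the pipeline above
    # expects this to be 2 on the HF StaticCache INT path.
    graph = exported_program.graph_module.graph
    return sum(
        1
        for n in graph.nodes
        if n.op == "call_function" and "executorch_call_delegate" in str(n.target)
    )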

backends/transforms/decompose_sdpa.py

Lines changed: 129 additions & 54 deletions
@@ -22,6 +22,11 @@ class DecomposeScaledDotProductAttention(ExportPass):
     """

     _passes_required_after: Set[Type[ExportPass]] = set()
+    _SDPA_OPTIONAL_ARGS = (
+        ("attn_mask", None),
+        ("dropout_p", 0.0),
+        ("is_causal", False),
+    )

     def __init__(self, allow_non_fake_inputs: bool = True) -> None:
         super().__init__()
@@ -42,19 +47,137 @@ def call(
         graph_module.recompile()
         return PassResult(graph_module, True)

+    @staticmethod
+    def _extract_arg_value(arg: object) -> object:
+        if isinstance(arg, torch.fx.Node):
+            if "val" not in arg.meta:
+                raise RuntimeError(f"Missing meta['val'] for SDPA arg node: {arg.name}")
+            return arg.meta["val"]
+        return arg
+
+    @classmethod
+    def _canonicalize_sdpa_call(
+        cls, node: torch.fx.Node
+    ) -> tuple[tuple[object, ...], object, object]:
+        input_args = list(node.args)
+        input_kwargs = dict(node.kwargs)
+
+        canonical_args = list(input_args[:3])
+        for arg_index, (arg_name, default) in enumerate(
+            cls._SDPA_OPTIONAL_ARGS, start=3
+        ):
+            if len(input_args) > arg_index:
+                canonical_args.append(input_args[arg_index])
+            else:
+                canonical_args.append(input_kwargs.pop(arg_name, default))
+
+        raw_scale = input_kwargs.pop("scale", None)
+        canonical_args.append(raw_scale)
+        scale = cls._extract_arg_value(raw_scale)
+        enable_gqa = cls._extract_arg_value(input_kwargs.pop("enable_gqa", False))
+        if input_kwargs:
+            raise RuntimeError(
+                "Unsupported kwargs for scaled_dot_product_attention: "
+                f"{', '.join(sorted(input_kwargs.keys()))}"
+            )
+
+        return tuple(canonical_args), scale, enable_gqa
+
+    @staticmethod
+    def _copy_decomposed_graph(
+        graph: torch.fx.Graph,
+        node: torch.fx.Node,
+        decomposed_module: torch.fx.GraphModule,
+        canonical_inputs: tuple[object, ...],
+        scale: object,
+    ) -> None:
+        decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
+        last_decomposed_node = None
+        placeholder_nodes = [
+            decomposed_node
+            for decomposed_node in decomposed_module.graph.nodes
+            if decomposed_node.op == "placeholder"
+        ]
+        if len(placeholder_nodes) != len(canonical_inputs):
+            raise RuntimeError(
+                "Unexpected placeholder count when decomposing "
+                "scaled_dot_product_attention"
+            )
+        for decomposed_node, arg in zip(placeholder_nodes, canonical_inputs):
+            decomposed_node_to_subgraph_node[decomposed_node] = arg
+
+        for decomposed_node in decomposed_module.graph.nodes:
+            if decomposed_node.op == "output":
+                last_decomposed_node = decomposed_node.args[0]
+
+        for decomposed_node in decomposed_module.graph.nodes:
+            decomposed_node.meta["nn_module_stack"] = node.meta.get("nn_module_stack")
+            if decomposed_node.op == "placeholder":
+                continue
+
+            if decomposed_node.op == "output" and last_decomposed_node is not None:
+                for user in node.users.copy():
+                    user.replace_input_with(
+                        node,
+                        decomposed_node_to_subgraph_node[last_decomposed_node],
+                    )
+                continue
+
+            if scale is not None and decomposed_node.target in [
+                torch.ops.aten.mul.Scalar
+            ]:
+                new_args = list(decomposed_node.args)
+                new_args[1] = math.sqrt(scale)
+                decomposed_node.args = tuple(new_args)
+
+            subgraph_node = graph.node_copy(
+                decomposed_node,
+                arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
+            )
+            subgraph_node.meta["source_fn_stack"] = [
+                (subgraph_node, subgraph_node.target)
+            ]
+            decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
+
     def _decompose_sdpa_node(
         self,
         graph_module: torch.fx.GraphModule,
         node: torch.fx.Node,
         allow_non_fake_inputs: bool,
     ) -> None:
         graph = graph_module.graph
-        input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
-        scale = node.kwargs.get("scale", None)
+
+        canonical_inputs, scale, enable_gqa = self._canonicalize_sdpa_call(node)
+        input_tensors = tuple(self._extract_arg_value(arg) for arg in canonical_inputs)
+
+        def _sdpa_with_gqa(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+        ):
+            # Shapes: (B, H, T, D)
+            Hq = q.shape[1]
+            Hk = k.shape[1]
+            if Hq != Hk:
+                # LLaMA-style GQA: tile K and V heads to match Q
+                if Hq % Hk != 0:
+                    raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
+                r = Hq // Hk
+                B, _, Tk, D = k.shape
+                k = k.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
+                v = v.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
+            return torch.ops.aten.scaled_dot_product_attention.default(
+                q,
+                k,
+                v,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )

         # refer to pytorch/test/test_decomp.py
         decomposed_module = make_fx(
-            node.target,
+            _sdpa_with_gqa,
             decomposition_table=get_decompositions(  # pyre-fixme[6]
                 [
                     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
@@ -65,56 +188,8 @@ def _decompose_sdpa_node(
         )(*input_tensors)

         with graph.inserting_before(node):
-            name_to_input_tensor_map = {}
-            for i, arg in enumerate(node.args):
-                name_to_input_tensor_map[f"arg{i}_1"] = arg
-
-            decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
-            last_decomposed_node = None
-            # Create a mapping from input nodes in decomposed module to original nodes.
-            # In decomposed module, there are only input tensors for placeholder op.
-            for decomposed_node in decomposed_module.graph.nodes:
-                if decomposed_node.op == "placeholder":
-                    decomposed_node_to_subgraph_node[decomposed_node] = (
-                        name_to_input_tensor_map[decomposed_node.name]
-                    )
-
-                if decomposed_node.op == "output":
-                    last_decomposed_node = decomposed_node.args[0]
-
-            # Copy node from decompose graph module
-            for decomposed_node in decomposed_module.graph.nodes:
-                node.meta["nn_module_stack"] = decomposed_node.meta.get(
-                    "nn_module_stack"
-                )
-                if decomposed_node.op == "placeholder":
-                    continue
-
-                if decomposed_node.op == "output" and last_decomposed_node is not None:
-                    for user in node.users.copy():
-                        user.replace_input_with(
-                            node,
-                            decomposed_node_to_subgraph_node[last_decomposed_node],
-                        )
-                    continue
-
-                if scale is not None and decomposed_node.target in [
-                    torch.ops.aten.mul.Scalar
-                ]:
-                    new_args = list(decomposed_node.args)
-                    # Based on the implementation of _scaled_dot_product_attention_math,
-                    # the scale is applied to q and k before matmul.
-                    # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
-                    new_args[1] = math.sqrt(scale)
-                    decomposed_node.args = tuple(new_args)
-
-                subgraph_node = graph.node_copy(
-                    decomposed_node,
-                    arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
-                )
-                subgraph_node.meta["source_fn_stack"] = [
-                    (subgraph_node, subgraph_node.target)
-                ]
-                decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
+            self._copy_decomposed_graph(
+                graph, node, decomposed_module, canonical_inputs, scale
+            )

         graph.erase_node(node)
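The _sdpa_with_gqa wrapper introduced above tiles K and V heads so grouped-query attention becomes a plain equal-head SDPA before decomposition. A small self-check of that equivalence, separate from the pass (assumes a PyTorch build whose scaled_dot_product_attention accepts enable_gqa):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, Hq, Hk, T, D = 1, 8, 2, 4, 16          # 8 query heads grouped over 2 KV heads
q = torch.randn(B, Hq, T, D)
k = torch.randn(B, Hk, T, D)
v = torch.randn(B, Hk, T, D)

# Reference: let SDPA broadcast the KV heads itself.
ref = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Same tiling as _sdpa_with_gqa: repeat each KV head Hq // Hk times.
r = Hq // Hk
k_t = k.unsqueeze(2).expand(B, Hk, r, T, D).reshape(B, Hq, T, D)
v_t = v.unsqueeze(2).expand(B, Hk, r, T, D).reshape(B, Hq, T, D)
out = F.scaled_dot_product_attention(q, k_t, v_t)

print(torch.allclose(ref, out, atol=1e-5))  # True: tiling matches enable_gqa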
