Skip to content

Commit 91206c1

Browse files
wwwindxingguo01
authored and committed
Arm backend: Add static cache integration test with llama
Change-Id: I881fa107f43c9682c18480d01996a5795ae7f086 Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 31a15a4 commit 91206c1

3 files changed

Lines changed: 181 additions & 58 deletions

File tree

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,13 @@ def _add_bias(
174174
output_dtype = node.meta["val"].dtype
175175
bias_data = torch.zeros(size=(output_channels,), dtype=output_dtype)
176176

177-
with graph_module.graph.inserting_after(weight_node):
177+
# Constant placeholders must appear before user-input placeholders in
178+
# the graph. Insert the synthetic bias at the first placeholder slot
179+
# instead of near the conv node.
180+
first_placeholder = next(
181+
n for n in graph_module.graph.nodes if n.op == "placeholder"
182+
)
183+
with graph_module.graph.inserting_before(first_placeholder):
178184
bias_node = create_constant_placeholder(
179185
self.exported_program,
180186
graph=graph_module.graph,

backends/arm/test/models/test_llama.py

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131

3232
from executorch.extension.llm.export.config.llm_config import LlmConfig
3333

34+
from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM
35+
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
36+
3437
input_t = Tuple[torch.Tensor]
38+
input_th = Tuple[torch.Tensor, torch.Tensor]
3539

3640
# Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
3741
this_files_dir = os.path.dirname(os.path.abspath(__file__))
@@ -41,6 +45,22 @@
4145
logger = logging.getLogger(__name__)
4246

4347

48+
class HFPositionalAdapter(torch.nn.Module):
49+
def __init__(self, exportable):
50+
super().__init__()
51+
self.inner = exportable
52+
53+
def forward(self, input_ids, cache_position):
54+
# HF StaticCache eager path requires int64 index tensors, but keeping
55+
# cache_position as int32 during export capture avoids adding an extra
56+
# int64->int32 cast node in the lowered graph.
57+
if torch._dynamo.is_compiling():
58+
cp = cache_position
59+
else:
60+
cp = cache_position.to(torch.long)
61+
return self.inner(input_ids=input_ids, cache_position=cp)
62+
63+
4464
class TestLlama:
4565
"""Test class of Llama models.
4666
@@ -51,6 +71,44 @@ class TestLlama:
5171
5272
"""
5373

74+
def prepare_model_hf_static(self):
75+
"""
76+
Build a tiny HF LLaMA wrapped with TorchExportableModuleForDecoderOnlyLM (StaticCache)
77+
See https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L214C17-L214C53
78+
"""
79+
# Tiny config
80+
cfg = LlamaConfig(
81+
vocab_size=32000,
82+
hidden_size=256,
83+
intermediate_size=512,
84+
num_hidden_layers=2,
85+
num_attention_heads=4,
86+
num_key_value_heads=2,
87+
use_cache=True,
88+
)
89+
base = LlamaForCausalLM(cfg).eval()
90+
91+
# REQUIRED: generation_config must request a 'static' cache with batch_size & max_cache_len
92+
base.generation_config = GenerationConfig(
93+
use_cache=True,
94+
cache_implementation="static",
95+
cache_config={"batch_size": 1, "max_cache_len": 128},
96+
)
97+
98+
exportable = TorchExportableModuleForDecoderOnlyLM(
99+
model=base, batch_size=1, max_cache_len=128
100+
)
101+
102+
# Positional adapter so the pipeline can call module(*inputs)
103+
model_for_pipeline = HFPositionalAdapter(exportable).eval()
104+
105+
# The tester will call model(*inputs). Provide (input_ids, cache_position)
106+
input_ids = torch.tensor([[0]], dtype=torch.long) # shape [1, 1]
107+
cache_position = torch.tensor([0], dtype=torch.int32) # shape [1]
108+
inputs = (input_ids, cache_position)
109+
110+
return model_for_pipeline, inputs, None
111+
54112
def prepare_model(self):
55113
checkpoint = None
56114
params_file = None
@@ -86,13 +144,18 @@ def prepare_model(self):
86144
# TODO: Enable key value cache
87145
args = [
88146
"--disable_dynamic_shape",
147+
"--max_seq_length",
148+
"4096",
149+
"--max_context_length",
150+
"4096",
89151
"-c",
90152
checkpoint,
91153
"-p",
92154
params_file,
93155
"--model",
94156
model_name,
95157
]
158+
96159
parser = build_args_parser()
97160
args = parser.parse_args(args)
98161
llm_config = LlmConfig.from_args(args)
@@ -123,11 +186,10 @@ def test_llama_tosa_FP():
123186
aten_op=[],
124187
exir_op=[],
125188
custom_path="llama_tosa_fb",
126-
run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk
189+
run_on_tosa_ref_model=True,  # Verify numerics on the TOSA reference model
127190
use_to_edge_transform_and_lower=True,
128191
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
129192
)
130-
pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
131193
pipeline.run()
132194

133195

@@ -144,12 +206,36 @@ def test_llama_tosa_INT():
144206
aten_op=[],
145207
exir_op=[],
146208
custom_path="llama_tosa_fb_int",
147-
run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk
209+
run_on_tosa_ref_model=True,  # Verify numerics on the TOSA reference model
148210
use_to_edge_transform_and_lower=True,
149211
frobenius_threshold=None,
150212
cosine_threshold=None,
151213
)
152-
pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
214+
pipeline.run()
215+
216+
217+
def test_llama_tosa_INT_static():
218+
llama_model, llama_inputs, _ = TestLlama().prepare_model_hf_static()
219+
if llama_model is None or llama_inputs is None:
220+
pytest.skip("Missing model and/or input files")
221+
222+
with torch.no_grad():
223+
pipeline = TosaPipelineINT[input_th](
224+
llama_model,
225+
llama_inputs,
226+
aten_op=[],
227+
exir_op=[],
228+
custom_path="llama_tosa_hf_static_int",
229+
run_on_tosa_ref_model=True,
230+
use_to_edge_transform_and_lower=True,
231+
fold_quantize=False,
232+
)
233+
# NOTE: HF StaticCache INT currently keeps two delegated subgraphs
234+
# after partitioning on this path, so expect two delegate calls in EXIR.
235+
pipeline.change_args(
236+
"check_count.exir",
237+
{"torch.ops.higher_order.executorch_call_delegate": 2},
238+
)
153239
pipeline.run()
154240

155241

backends/transforms/decompose_sdpa.py

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,100 @@ def call(
4242
graph_module.recompile()
4343
return PassResult(graph_module, True)
4444

45+
@staticmethod
46+
def _extract_input_tensors(node: torch.fx.Node) -> tuple[object, ...]:
47+
def _extract_arg_value(arg):
48+
if isinstance(arg, torch.fx.Node):
49+
if "val" not in arg.meta:
50+
raise RuntimeError(
51+
f"Missing meta['val'] for SDPA arg node: {arg.name}"
52+
)
53+
return arg.meta["val"]
54+
return arg
55+
56+
return tuple(_extract_arg_value(arg) for arg in node.args)
57+
58+
@staticmethod
59+
def _copy_decomposed_graph(
60+
graph: torch.fx.Graph,
61+
node: torch.fx.Node,
62+
decomposed_module: torch.fx.GraphModule,
63+
scale: object,
64+
) -> None:
65+
name_to_input_tensor_map = {}
66+
for i, arg in enumerate(node.args):
67+
name_to_input_tensor_map[f"arg{i}_1"] = arg
68+
69+
decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
70+
last_decomposed_node = None
71+
for decomposed_node in decomposed_module.graph.nodes:
72+
if decomposed_node.op == "placeholder":
73+
decomposed_node_to_subgraph_node[decomposed_node] = (
74+
name_to_input_tensor_map[decomposed_node.name]
75+
)
76+
77+
if decomposed_node.op == "output":
78+
last_decomposed_node = decomposed_node.args[0]
79+
80+
for decomposed_node in decomposed_module.graph.nodes:
81+
node.meta["nn_module_stack"] = decomposed_node.meta.get("nn_module_stack")
82+
if decomposed_node.op == "placeholder":
83+
continue
84+
85+
if decomposed_node.op == "output" and last_decomposed_node is not None:
86+
for user in node.users.copy():
87+
user.replace_input_with(
88+
node,
89+
decomposed_node_to_subgraph_node[last_decomposed_node],
90+
)
91+
continue
92+
93+
if scale is not None and decomposed_node.target in [
94+
torch.ops.aten.mul.Scalar
95+
]:
96+
new_args = list(decomposed_node.args)
97+
new_args[1] = math.sqrt(scale)
98+
decomposed_node.args = tuple(new_args)
99+
100+
subgraph_node = graph.node_copy(
101+
decomposed_node,
102+
arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
103+
)
104+
subgraph_node.meta["source_fn_stack"] = [
105+
(subgraph_node, subgraph_node.target)
106+
]
107+
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
108+
45109
def _decompose_sdpa_node(
46110
self,
47111
graph_module: torch.fx.GraphModule,
48112
node: torch.fx.Node,
49113
allow_non_fake_inputs: bool,
50114
) -> None:
51115
graph = graph_module.graph
52-
input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
116+
117+
input_tensors = self._extract_input_tensors(node)
53118
scale = node.kwargs.get("scale", None)
54119

120+
def _sdpa_with_gqa(*args, **kwargs):
121+
# args: (q, k, v, [attn_mask, dropout_p, is_causal, scale])
122+
q, k, v = args[:3]
123+
# Shapes: (B, H, T, D)
124+
Hq = q.shape[1]
125+
Hk = k.shape[1]
126+
if Hq != Hk:
127+
# LLaMA-style GQA: tile K and V heads to match Q
128+
assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
129+
r = Hq // Hk
130+
B, _, Tk, D = k.shape
131+
k = k.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
132+
v = v.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
133+
args = (q, k, v) + tuple(args[3:])
134+
return torch.ops.aten.scaled_dot_product_attention.default(*args, **kwargs)
135+
55136
# refer to pytorch/test/test_decomp.py
56137
decomposed_module = make_fx(
57-
node.target,
138+
_sdpa_with_gqa,
58139
decomposition_table=get_decompositions( # pyre-fixme[6]
59140
[
60141
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
@@ -65,56 +146,6 @@ def _decompose_sdpa_node(
65146
)(*input_tensors)
66147

67148
with graph.inserting_before(node):
68-
name_to_input_tensor_map = {}
69-
for i, arg in enumerate(node.args):
70-
name_to_input_tensor_map[f"arg{i}_1"] = arg
71-
72-
decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
73-
last_decomposed_node = None
74-
# Create a mapping from input nodes in decomposed module to original nodes.
75-
# In decomposed module, there are only input tensors for placeholder op.
76-
for decomposed_node in decomposed_module.graph.nodes:
77-
if decomposed_node.op == "placeholder":
78-
decomposed_node_to_subgraph_node[decomposed_node] = (
79-
name_to_input_tensor_map[decomposed_node.name]
80-
)
81-
82-
if decomposed_node.op == "output":
83-
last_decomposed_node = decomposed_node.args[0]
84-
85-
# Copy node from decompose graph module
86-
for decomposed_node in decomposed_module.graph.nodes:
87-
node.meta["nn_module_stack"] = decomposed_node.meta.get(
88-
"nn_module_stack"
89-
)
90-
if decomposed_node.op == "placeholder":
91-
continue
92-
93-
if decomposed_node.op == "output" and last_decomposed_node is not None:
94-
for user in node.users.copy():
95-
user.replace_input_with(
96-
node,
97-
decomposed_node_to_subgraph_node[last_decomposed_node],
98-
)
99-
continue
100-
101-
if scale is not None and decomposed_node.target in [
102-
torch.ops.aten.mul.Scalar
103-
]:
104-
new_args = list(decomposed_node.args)
105-
# Based on the implementation of _scaled_dot_product_attention_math,
106-
# the scale is applied to q and k before matmul.
107-
# refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
108-
new_args[1] = math.sqrt(scale)
109-
decomposed_node.args = tuple(new_args)
110-
111-
subgraph_node = graph.node_copy(
112-
decomposed_node,
113-
arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
114-
)
115-
subgraph_node.meta["source_fn_stack"] = [
116-
(subgraph_node, subgraph_node.target)
117-
]
118-
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
149+
self._copy_decomposed_graph(graph, node, decomposed_module, scale)
119150

120151
graph.erase_node(node)

0 commit comments

Comments
 (0)