8 changes: 7 additions & 1 deletion backends/arm/_passes/rewrite_conv_pass.py
@@ -174,7 +174,13 @@ def _add_bias(
output_dtype = node.meta["val"].dtype
bias_data = torch.zeros(size=(output_channels,), dtype=output_dtype)

with graph_module.graph.inserting_after(weight_node):
# Constant placeholders must appear before user-input placeholders in
# the graph. Insert the synthetic bias at the first placeholder slot
# instead of near the conv node.
first_placeholder = next(
n for n in graph_module.graph.nodes if n.op == "placeholder"
)
with graph_module.graph.inserting_before(first_placeholder):
bias_node = create_constant_placeholder(
self.exported_program,
graph=graph_module.graph,
94 changes: 90 additions & 4 deletions backends/arm/test/models/test_llama.py
@@ -31,7 +31,11 @@

from executorch.extension.llm.export.config.llm_config import LlmConfig

from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

input_t = Tuple[torch.Tensor]
input_th = Tuple[torch.Tensor, torch.Tensor]

# Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
this_files_dir = os.path.dirname(os.path.abspath(__file__))
@@ -41,6 +45,22 @@
logger = logging.getLogger(__name__)


class HFPositionalAdapter(torch.nn.Module):
def __init__(self, exportable):
super().__init__()
self.inner = exportable

def forward(self, input_ids, cache_position):
# HF StaticCache eager path requires int64 index tensors, but keeping
# cache_position as int32 during export capture avoids adding an extra
# int64->int32 cast node in the lowered graph.
if torch._dynamo.is_compiling():
cp = cache_position
else:
cp = cache_position.to(torch.long)
return self.inner(input_ids=input_ids, cache_position=cp)


class TestLlama:
"""Test class of Llama models.

@@ -51,6 +71,44 @@ class TestLlama:

"""

def prepare_model_hf_static(self):
"""
Build a tiny HF LLaMA wrapped with TorchExportableModuleForDecoderOnlyLM (StaticCache)
See https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L214C17-L214C53
"""
# Tiny config
cfg = LlamaConfig(
vocab_size=32000,
hidden_size=256,
intermediate_size=512,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
use_cache=True,
)
base = LlamaForCausalLM(cfg).eval()

# REQUIRED: generation_config must request a 'static' cache with batch_size & max_cache_len
base.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 128},
)

exportable = TorchExportableModuleForDecoderOnlyLM(
model=base, batch_size=1, max_cache_len=128
)

# Positional adapter so the pipeline can call module(*inputs)
model_for_pipeline = HFPositionalAdapter(exportable).eval()

# The tester will call model(*inputs). Provide (input_ids, cache_position)
input_ids = torch.tensor([[0]], dtype=torch.long) # shape [1, 1]
cache_position = torch.tensor([0], dtype=torch.int32) # shape [1]
inputs = (input_ids, cache_position)

return model_for_pipeline, inputs, None

def prepare_model(self):
checkpoint = None
params_file = None
@@ -86,13 +144,18 @@ def prepare_model(self):
# TODO: Enable key value cache
args = [
"--disable_dynamic_shape",
"--max_seq_length",
"4096",
"--max_context_length",
"4096",
"-c",
checkpoint,
"-p",
params_file,
"--model",
model_name,
]

parser = build_args_parser()
args = parser.parse_args(args)
llm_config = LlmConfig.from_args(args)
@@ -123,11 +186,10 @@ def test_llama_tosa_FP():
aten_op=[],
exir_op=[],
custom_path="llama_tosa_fb",
run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk
run_on_tosa_ref_model=True, # Just want to write TOSA FB to disk
use_to_edge_transform_and_lower=True,
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
Comment on lines 188 to 191

Copilot AI Apr 14, 2026

The inline comment says this test “Just want to write TOSA FB to disk”, but run_on_tosa_ref_model is now True (and the explicit serialize stage was removed). Either update the comment to match the new behavior (running the TOSA ref model) or set run_on_tosa_ref_model=False if the intent is still artifact-only.
)
pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
pipeline.run()


Expand All @@ -144,12 +206,36 @@ def test_llama_tosa_INT():
aten_op=[],
exir_op=[],
custom_path="llama_tosa_fb_int",
run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk
run_on_tosa_ref_model=True, # Just want to write TOSA FB to disk
use_to_edge_transform_and_lower=True,
frobenius_threshold=None,
cosine_threshold=None,
Comment on lines 208 to 212

Copilot AI Apr 14, 2026

The inline comment says this test “Just want to write TOSA FB to disk”, but run_on_tosa_ref_model is now True (and the explicit serialize stage was removed). Either update the comment to match the new behavior (running the TOSA ref model) or set run_on_tosa_ref_model=False if the intent is still artifact-only.
)
pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
pipeline.run()


def test_llama_tosa_INT_static():
llama_model, llama_inputs, _ = TestLlama().prepare_model_hf_static()
if llama_model is None or llama_inputs is None:
pytest.skip("Missing model and/or input files")

with torch.no_grad():
pipeline = TosaPipelineINT[input_th](
llama_model,
llama_inputs,
aten_op=[],
exir_op=[],
custom_path="llama_tosa_hf_static_int",
run_on_tosa_ref_model=True,
use_to_edge_transform_and_lower=True,
fold_quantize=False,
)
# NOTE: HF StaticCache INT currently keeps two delegated subgraphs
# after partitioning on this path, so expect two delegate calls in EXIR.
pipeline.change_args(
"check_count.exir",
{"torch.ops.higher_order.executorch_call_delegate": 2},
)
pipeline.run()


137 changes: 84 additions & 53 deletions backends/transforms/decompose_sdpa.py
@@ -42,19 +42,100 @@ def call(
graph_module.recompile()
return PassResult(graph_module, True)

@staticmethod
def _extract_input_tensors(node: torch.fx.Node) -> tuple[object, ...]:
def _extract_arg_value(arg):
if isinstance(arg, torch.fx.Node):
if "val" not in arg.meta:
raise RuntimeError(
f"Missing meta['val'] for SDPA arg node: {arg.name}"
)
return arg.meta["val"]
return arg

return tuple(_extract_arg_value(arg) for arg in node.args)

Comment on lines +45 to +57

Copilot AI Apr 14, 2026

_extract_input_tensors only walks node.args and ignores node.kwargs. For aten.scaled_dot_product_attention it’s common for attn_mask / dropout_p / is_causal / scale to be provided as kwargs, so the make_fx trace here can silently use defaults and decompose the wrong computation. Consider canonicalizing the SDPA call into a full positional arg list (q, k, v, attn_mask, dropout_p, is_causal, scale) by merging args + kwargs + defaults, and use that both for tracing and for the later scale adjustment (including handling scale passed positionally).
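A minimal sketch of that canonicalization idea (the helper name, the positional order taken from the aten SDPA signature, and the defaults are illustrative, not part of this PR):

import torch

_SDPA_ARG_NAMES = ("query", "key", "value", "attn_mask", "dropout_p", "is_causal")
_SDPA_DEFAULTS = {"attn_mask": None, "dropout_p": 0.0, "is_causal": False, "scale": None}

def _canonicalize_sdpa_call(node: torch.fx.Node):
    # Merge positional args, kwargs, and defaults into one map so the make_fx
    # trace and the later scale adjustment both see the same values.
    merged = dict(_SDPA_DEFAULTS)
    merged.update(zip(_SDPA_ARG_NAMES, node.args))
    merged.update(node.kwargs)
    positional = tuple(merged[name] for name in _SDPA_ARG_NAMES)
    return positional, merged["scale"]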
@staticmethod
def _copy_decomposed_graph(
graph: torch.fx.Graph,
node: torch.fx.Node,
decomposed_module: torch.fx.GraphModule,
scale: object,
) -> None:
name_to_input_tensor_map = {}
for i, arg in enumerate(node.args):
name_to_input_tensor_map[f"arg{i}_1"] = arg

decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
last_decomposed_node = None
for decomposed_node in decomposed_module.graph.nodes:
if decomposed_node.op == "placeholder":
decomposed_node_to_subgraph_node[decomposed_node] = (
name_to_input_tensor_map[decomposed_node.name]
)

if decomposed_node.op == "output":
last_decomposed_node = decomposed_node.args[0]

for decomposed_node in decomposed_module.graph.nodes:
node.meta["nn_module_stack"] = decomposed_node.meta.get("nn_module_stack")
if decomposed_node.op == "placeholder":
continue
Comment on lines +80 to +83

Copilot AI Apr 14, 2026

In _copy_decomposed_graph, nn_module_stack metadata is being written onto the original SDPA node (which is erased) instead of propagating from the original node to the decomposed nodes / copied subgraph nodes. This likely drops nn_module_stack on the new nodes and breaks downstream tooling relying on that metadata. The direction should match other decomposition passes (e.g., set decomposed_node.meta["nn_module_stack"] = node.meta.get("nn_module_stack") before node_copy, or set it on subgraph_node after copying).
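A rough sketch of the suggested direction, reusing the names from _copy_decomposed_graph above (placeholder and output handling elided; illustrative only, not the PR's code):

for decomposed_node in decomposed_module.graph.nodes:
    if decomposed_node.op in ("placeholder", "output"):
        continue
    subgraph_node = graph.node_copy(
        decomposed_node,
        arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
    )
    # Propagate the stack from the original SDPA node onto the copied node
    # rather than overwriting node.meta with the decomposed node's metadata.
    subgraph_node.meta["nn_module_stack"] = node.meta.get("nn_module_stack")
    decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node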

if decomposed_node.op == "output" and last_decomposed_node is not None:
for user in node.users.copy():
user.replace_input_with(
node,
decomposed_node_to_subgraph_node[last_decomposed_node],
)
continue

if scale is not None and decomposed_node.target in [
torch.ops.aten.mul.Scalar
]:
new_args = list(decomposed_node.args)
new_args[1] = math.sqrt(scale)
decomposed_node.args = tuple(new_args)

subgraph_node = graph.node_copy(
decomposed_node,
arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
)
subgraph_node.meta["source_fn_stack"] = [
(subgraph_node, subgraph_node.target)
]
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node

def _decompose_sdpa_node(
self,
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
allow_non_fake_inputs: bool,
) -> None:
graph = graph_module.graph
input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)

input_tensors = self._extract_input_tensors(node)
scale = node.kwargs.get("scale", None)

def _sdpa_with_gqa(*args, **kwargs):
# args: (q, k, v, [attn_mask, dropout_p, is_causal, scale])
q, k, v = args[:3]
# Shapes: (B, H, T, D)
Hq = q.shape[1]
Hk = k.shape[1]
if Hq != Hk:
# LLaMA-style GQA: tile K and V heads to match Q
assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
Copilot AI Apr 14, 2026

Using a bare assert for the GQA head-ratio check makes this validation disappear under Python -O and can turn a shape mismatch into harder-to-debug downstream errors during tracing. Prefer raising a RuntimeError / ValueError with the same message so it is always enforced.

Suggested change
assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
if Hq % Hk != 0:
    raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
r = Hq // Hk
B, _, Tk, D = k.shape
k = k.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
v = v.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
args = (q, k, v) + tuple(args[3:])
return torch.ops.aten.scaled_dot_product_attention.default(*args, **kwargs)

# refer to pytorch/test/test_decomp.py
decomposed_module = make_fx(
node.target,
_sdpa_with_gqa,
decomposition_table=get_decompositions( # pyre-fixme[6]
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
@@ -65,56 +146,6 @@ def _decompose_sdpa_node(
)(*input_tensors)

with graph.inserting_before(node):
name_to_input_tensor_map = {}
Contributor

Is this just a refactor?

for i, arg in enumerate(node.args):
name_to_input_tensor_map[f"arg{i}_1"] = arg

decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
last_decomposed_node = None
# Create a mapping from input nodes in decomposed module to original nodes.
# In decomposed module, there are only input tensors for placeholder op.
for decomposed_node in decomposed_module.graph.nodes:
if decomposed_node.op == "placeholder":
decomposed_node_to_subgraph_node[decomposed_node] = (
name_to_input_tensor_map[decomposed_node.name]
)

if decomposed_node.op == "output":
last_decomposed_node = decomposed_node.args[0]

# Copy node from decompose graph module
for decomposed_node in decomposed_module.graph.nodes:
node.meta["nn_module_stack"] = decomposed_node.meta.get(
"nn_module_stack"
)
if decomposed_node.op == "placeholder":
continue

if decomposed_node.op == "output" and last_decomposed_node is not None:
for user in node.users.copy():
user.replace_input_with(
node,
decomposed_node_to_subgraph_node[last_decomposed_node],
)
continue

if scale is not None and decomposed_node.target in [
torch.ops.aten.mul.Scalar
]:
new_args = list(decomposed_node.args)
# Based on the implementation of _scaled_dot_product_attention_math,
# the scale is applied to q and k before matmul.
# refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
new_args[1] = math.sqrt(scale)
decomposed_node.args = tuple(new_args)

subgraph_node = graph.node_copy(
decomposed_node,
arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
)
subgraph_node.meta["source_fn_stack"] = [
(subgraph_node, subgraph_node.target)
]
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
self._copy_decomposed_graph(graph, node, decomposed_module, scale)

graph.erase_node(node)