diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py
index 8244dc2558b..e253e80f145 100644
--- a/backends/arm/_passes/rewrite_conv_pass.py
+++ b/backends/arm/_passes/rewrite_conv_pass.py
@@ -194,7 +194,13 @@ def _add_bias(
         output_dtype = node.meta["val"].dtype
         bias_data = torch.zeros(size=(output_channels,), dtype=output_dtype)
 
-        with graph_module.graph.inserting_after(weight_node):
+        # Constant placeholders must appear before user-input placeholders in
+        # the graph. Insert the synthetic bias at the first placeholder slot
+        # instead of near the conv node.
+        first_placeholder = next(
+            n for n in graph_module.graph.nodes if n.op == "placeholder"
+        )
+        with graph_module.graph.inserting_before(first_placeholder):
             bias_node = create_constant_placeholder(
                 self.exported_program,
                 graph=graph_module.graph,
diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index 704e3e07926..1602aa7b4ba 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -31,7 +31,11 @@
 
 from executorch.extension.llm.export.config.llm_config import LlmConfig
 
+from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM
+from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
 input_t = Tuple[torch.Tensor]
+input_th = Tuple[torch.Tensor, torch.Tensor]
 
 # Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
 this_files_dir = os.path.dirname(os.path.abspath(__file__))
@@ -41,6 +45,22 @@
 logger = logging.getLogger(__name__)
 
 
+class HFPositionalAdapter(torch.nn.Module):
+    def __init__(self, exportable):
+        super().__init__()
+        self.inner = exportable
+
+    def forward(self, input_ids, cache_position):
+        # HF StaticCache's eager path requires int64 index tensors, but keeping
+        # cache_position as int32 during export capture avoids adding an extra
+        # int64->int32 cast node in the lowered graph.
+        if torch._dynamo.is_compiling():
+            cp = cache_position
+        else:
+            cp = cache_position.to(torch.long)
+        return self.inner(input_ids=input_ids, cache_position=cp)
+
+
 class TestLlama:
     """Test class of Llama models.
 
@@ -51,6 +71,44 @@ class TestLlama:
 
     """
 
+    def prepare_model_hf_static(self):
+        """
+        Build a tiny HF LLaMA wrapped with TorchExportableModuleForDecoderOnlyLM (StaticCache).
+        See https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L214C17-L214C53
+        """
+        # Tiny config
+        cfg = LlamaConfig(
+            vocab_size=32000,
+            hidden_size=256,
+            intermediate_size=512,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            use_cache=True,
+        )
+        base = LlamaForCausalLM(cfg).eval()
+
+        # REQUIRED: generation_config must request a 'static' cache with batch_size & max_cache_len
+        base.generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="static",
+            cache_config={"batch_size": 1, "max_cache_len": 128},
+        )
+
+        exportable = TorchExportableModuleForDecoderOnlyLM(
+            model=base, batch_size=1, max_cache_len=128
+        )
+
+        # Positional adapter so the pipeline can call module(*inputs)
+        model_for_pipeline = HFPositionalAdapter(exportable).eval()
+
+        # The tester will call model(*inputs). Provide (input_ids, cache_position).
+        input_ids = torch.tensor([[0]], dtype=torch.long)  # shape [1, 1]
+        cache_position = torch.tensor([0], dtype=torch.int32)  # shape [1]
+        inputs = (input_ids, cache_position)
+
+        return model_for_pipeline, inputs, None
+
     def prepare_model(self):
         checkpoint = None
         params_file = None
@@ -86,6 +144,10 @@ def prepare_model(self):
         # TODO: Enable key value cache
         args = [
             "--disable_dynamic_shape",
+            "--max_seq_length",
+            "4096",
+            "--max_context_length",
+            "4096",
             "-c",
             checkpoint,
             "-p",
@@ -93,6 +155,7 @@ def prepare_model(self):
             "--model",
             model_name,
         ]
+
         parser = build_args_parser()
         args = parser.parse_args(args)
         llm_config = LlmConfig.from_args(args)
@@ -123,11 +186,10 @@ def test_llama_tosa_FP():
         aten_op=[],
         exir_op=[],
         custom_path="llama_tosa_fb",
-        run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
+        run_on_tosa_ref_model=True,
         use_to_edge_transform_and_lower=True,
         transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
     )
-    pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
     pipeline.run()
 
 
@@ -144,12 +206,36 @@ def test_llama_tosa_INT():
         aten_op=[],
         exir_op=[],
         custom_path="llama_tosa_fb_int",
-        run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
+        run_on_tosa_ref_model=True,
         use_to_edge_transform_and_lower=True,
         frobenius_threshold=None,
         cosine_threshold=None,
     )
-    pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
+    pipeline.run()
+
+
+def test_llama_tosa_INT_static():
+    llama_model, llama_inputs, _ = TestLlama().prepare_model_hf_static()
+    if llama_model is None or llama_inputs is None:
+        pytest.skip("Missing model and/or input files")
+
+    with torch.no_grad():
+        pipeline = TosaPipelineINT[input_th](
+            llama_model,
+            llama_inputs,
+            aten_op=[],
+            exir_op=[],
+            custom_path="llama_tosa_hf_static_int",
+            run_on_tosa_ref_model=True,
+            use_to_edge_transform_and_lower=True,
+            fold_quantize=True,
+        )
+        # NOTE: HF StaticCache INT currently keeps two delegated subgraphs
+        # after partitioning on this path, so expect two delegate calls in EXIR.
+        pipeline.change_args(
+            "check_count.exir",
+            {"torch.ops.higher_order.executorch_call_delegate": 2},
+        )
     pipeline.run()
diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py
index 13acaa32f11..34afcffd8c5 100644
--- a/backends/transforms/decompose_sdpa.py
+++ b/backends/transforms/decompose_sdpa.py
@@ -22,6 +22,11 @@ class DecomposeScaledDotProductAttention(ExportPass):
     """
 
     _passes_required_after: Set[Type[ExportPass]] = set()
+    _SDPA_OPTIONAL_ARGS = (
+        ("attn_mask", None),
+        ("dropout_p", 0.0),
+        ("is_causal", False),
+    )
 
     def __init__(self, allow_non_fake_inputs: bool = True) -> None:
         super().__init__()
@@ -42,6 +47,98 @@ def call(
         graph_module.recompile()
         return PassResult(graph_module, True)
 
+    @staticmethod
+    def _extract_arg_value(arg: object) -> object:
+        if isinstance(arg, torch.fx.Node):
+            if "val" not in arg.meta:
+                raise RuntimeError(f"Missing meta['val'] for SDPA arg node: {arg.name}")
+            return arg.meta["val"]
+        return arg
+
+    @classmethod
+    def _canonicalize_sdpa_call(
+        cls, node: torch.fx.Node
+    ) -> tuple[tuple[object, ...], object, object]:
+        input_args = list(node.args)
+        input_kwargs = dict(node.kwargs)
+
+        canonical_args = list(input_args[:3])
+        for arg_index, (arg_name, default) in enumerate(
+            cls._SDPA_OPTIONAL_ARGS, start=3
+        ):
+            if len(input_args) > arg_index:
+                canonical_args.append(input_args[arg_index])
+            else:
+                canonical_args.append(input_kwargs.pop(arg_name, default))
+
+        raw_scale = input_kwargs.pop("scale", None)
+        canonical_args.append(raw_scale)
+        scale = cls._extract_arg_value(raw_scale)
+        enable_gqa = cls._extract_arg_value(input_kwargs.pop("enable_gqa", False))
+        if input_kwargs:
+            raise RuntimeError(
+                "Unsupported kwargs for scaled_dot_product_attention: "
+                f"{', '.join(sorted(input_kwargs.keys()))}"
+            )
+
+        return tuple(canonical_args), scale, enable_gqa
+
+    @staticmethod
+    def _copy_decomposed_graph(
+        graph: torch.fx.Graph,
+        node: torch.fx.Node,
+        decomposed_module: torch.fx.GraphModule,
+        canonical_inputs: tuple[object, ...],
+        scale: object,
+    ) -> None:
+        decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
+        last_decomposed_node = None
+        placeholder_nodes = [
+            decomposed_node
+            for decomposed_node in decomposed_module.graph.nodes
+            if decomposed_node.op == "placeholder"
+        ]
+        if len(placeholder_nodes) != len(canonical_inputs):
+            raise RuntimeError(
+                "Unexpected placeholder count when decomposing "
+                "scaled_dot_product_attention"
+            )
+        for decomposed_node, arg in zip(placeholder_nodes, canonical_inputs):
+            decomposed_node_to_subgraph_node[decomposed_node] = arg
+
+        for decomposed_node in decomposed_module.graph.nodes:
+            if decomposed_node.op == "output":
+                last_decomposed_node = decomposed_node.args[0]
+
+        for decomposed_node in decomposed_module.graph.nodes:
+            decomposed_node.meta["nn_module_stack"] = node.meta.get("nn_module_stack")
+            if decomposed_node.op == "placeholder":
+                continue
+
+            if decomposed_node.op == "output" and last_decomposed_node is not None:
+                for user in node.users.copy():
+                    user.replace_input_with(
+                        node,
+                        decomposed_node_to_subgraph_node[last_decomposed_node],
+                    )
+                continue
+
+            if scale is not None and decomposed_node.target in [
+                torch.ops.aten.mul.Scalar
+            ]:
+                new_args = list(decomposed_node.args)
+                new_args[1] = math.sqrt(scale)
+                decomposed_node.args = tuple(new_args)
+
+            subgraph_node = graph.node_copy(
+                decomposed_node,
+                arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
+            )
+            subgraph_node.meta["source_fn_stack"] = [
+                (subgraph_node, subgraph_node.target)
+            ]
+            decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
+
     def _decompose_sdpa_node(
         self,
         graph_module: torch.fx.GraphModule,
@@ -49,12 +146,38 @@ def _decompose_sdpa_node(
         allow_non_fake_inputs: bool,
     ) -> None:
         graph = graph_module.graph
-        input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
-        scale = node.kwargs.get("scale", None)
+
+        canonical_inputs, scale, enable_gqa = self._canonicalize_sdpa_call(node)
+        input_tensors = tuple(self._extract_arg_value(arg) for arg in canonical_inputs)
+
+        def _sdpa_with_gqa(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+        ):
+            # Shapes: (B, H, T, D)
+            Hq = q.shape[1]
+            Hk = k.shape[1]
+            if Hq != Hk:
+                # LLaMA-style GQA: tile K and V heads to match Q
+                if Hq % Hk != 0:
+                    raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
+                r = Hq // Hk
+                B, _, Tk, D = k.shape
+                k = k.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
+                v = v.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
+            return torch.ops.aten.scaled_dot_product_attention.default(
+                q,
+                k,
+                v,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
 
         # refer to pytorch/test/test_decomp.py
         decomposed_module = make_fx(
-            node.target,
+            _sdpa_with_gqa,
             decomposition_table=get_decompositions(  # pyre-fixme[6]
                 [
                     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
@@ -65,56 +188,8 @@ def _decompose_sdpa_node(
         )(*input_tensors)
 
         with graph.inserting_before(node):
-            name_to_input_tensor_map = {}
-            for i, arg in enumerate(node.args):
-                name_to_input_tensor_map[f"arg{i}_1"] = arg
-
-            decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
-            last_decomposed_node = None
-            # Create a mapping from input nodes in decomposed module to original nodes.
-            # In decomposed module, there are only input tensors for placeholder op.
-            for decomposed_node in decomposed_module.graph.nodes:
-                if decomposed_node.op == "placeholder":
-                    decomposed_node_to_subgraph_node[decomposed_node] = (
-                        name_to_input_tensor_map[decomposed_node.name]
-                    )
-
-                if decomposed_node.op == "output":
-                    last_decomposed_node = decomposed_node.args[0]
-
-            # Copy node from decompose graph module
-            for decomposed_node in decomposed_module.graph.nodes:
-                node.meta["nn_module_stack"] = decomposed_node.meta.get(
-                    "nn_module_stack"
-                )
-                if decomposed_node.op == "placeholder":
-                    continue
-
-                if decomposed_node.op == "output" and last_decomposed_node is not None:
-                    for user in node.users.copy():
-                        user.replace_input_with(
-                            node,
-                            decomposed_node_to_subgraph_node[last_decomposed_node],
-                        )
-                    continue
-
-                if scale is not None and decomposed_node.target in [
-                    torch.ops.aten.mul.Scalar
-                ]:
-                    new_args = list(decomposed_node.args)
-                    # Based on the implementation of _scaled_dot_product_attention_math,
-                    # the scale is applied to q and k before matmul.
-                    # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
-                    new_args[1] = math.sqrt(scale)
-                    decomposed_node.args = tuple(new_args)
-
-                subgraph_node = graph.node_copy(
-                    decomposed_node,
-                    arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
-                )
-                subgraph_node.meta["source_fn_stack"] = [
-                    (subgraph_node, subgraph_node.target)
-                ]
-                decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
+            self._copy_decomposed_graph(
+                graph, node, decomposed_module, canonical_inputs, scale
+            )
 
         graph.erase_node(node)
diff --git a/backends/transforms/test/test_decompose_sdpa.py b/backends/transforms/test/test_decompose_sdpa.py
new file mode 100644
index 00000000000..5f486000bd4
--- /dev/null
+++ b/backends/transforms/test/test_decompose_sdpa.py
@@ -0,0 +1,76 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)
+from torch.export import export
+
+
+class TestDecomposeScaledDotProductAttention(unittest.TestCase):
+    def test_decompose_sdpa_preserves_kwargs(self) -> None:
+        class Block(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=mask,
+                    scale=0.25,
+                )
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.block = Block()
+
+            def forward(self, q, k, v, mask):
+                return self.block(q, k, v, mask)
+
+        q = torch.randn(1, 2, 3, 4)
+        k = torch.randn(1, 2, 3, 4)
+        v = torch.randn(1, 2, 3, 4)
+        mask = torch.tensor(
+            [[[[True, False, True], [True, True, False], [False, True, True]]]]
+        )
+
+        graph_module = export(Model().eval(), (q, k, v, mask), strict=True).module()
+
+        before_output = graph_module(q, k, v, mask)
+        original_nn_module_stack = None
+        self.assertTrue(
+            any(
+                node.target == torch.ops.aten.scaled_dot_product_attention.default
+                for node in graph_module.graph.nodes
+                if node.op == "call_function"
+            )
+        )
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and (
+                node.target == torch.ops.aten.scaled_dot_product_attention.default
+            ):
+                original_nn_module_stack = node.meta.get("nn_module_stack")
+                break
+
+        self.assertIsNotNone(original_nn_module_stack)
+
+        DecomposeScaledDotProductAttention()(graph_module)
+
+        self.assertFalse(
+            any(
+                node.target == torch.ops.aten.scaled_dot_product_attention.default
+                for node in graph_module.graph.nodes
+                if node.op == "call_function"
+            )
+        )
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function":
+                self.assertEqual(
+                    node.meta.get("nn_module_stack"), original_nn_module_stack
+                )
+        torch.testing.assert_close(graph_module(q, k, v, mask), before_output)
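
Note on the rewrite_conv_pass.py hunk: the new comment states that constant placeholders must appear before user-input placeholders, which is why the synthetic bias is now inserted at the first placeholder slot rather than after the weight node. A minimal torch.fx sketch of the insertion-point mechanics (the toy module and names are hypothetical; it only inspects node order and does not rebuild an ExportedProgram signature the way create_constant_placeholder does):

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(Toy())
first_placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")

# inserting_before(first_placeholder) puts the new placeholder at the very
# front of the graph, ahead of every user-input placeholder.
with gm.graph.inserting_before(first_placeholder):
    bias_ph = gm.graph.placeholder("synthetic_bias")

assert [n.name for n in gm.graph.nodes][0] == "synthetic_bias"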
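
Note on _canonicalize_sdpa_call: its job is to collapse positional and keyword SDPA call sites into a single canonical argument layout before decomposition. A stripped-down, self-contained sketch of the same normalization logic (plain strings stand in for fx nodes; the names here are illustrative, not the pass's API):

OPTIONAL_ARGS = (("attn_mask", None), ("dropout_p", 0.0), ("is_causal", False))

def canonicalize(args, kwargs):
    kwargs = dict(kwargs)
    canonical = list(args[:3])  # q, k, v are always positional
    for i, (name, default) in enumerate(OPTIONAL_ARGS, start=3):
        canonical.append(args[i] if len(args) > i else kwargs.pop(name, default))
    canonical.append(kwargs.pop("scale", None))
    enable_gqa = kwargs.pop("enable_gqa", False)
    if kwargs:
        raise RuntimeError(f"Unsupported kwargs: {sorted(kwargs)}")
    return tuple(canonical), enable_gqa

# Positional and keyword call sites normalize to the same tuple.
a, _ = canonicalize(("q", "k", "v", "mask"), {"scale": 0.25})
b, _ = canonicalize(("q", "k", "v"), {"attn_mask": "mask", "scale": 0.25})
assert a == b == ("q", "k", "v", "mask", 0.0, False, 0.25)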
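
Note on the GQA tiling in _sdpa_with_gqa: expanding K/V from Hk to Hq = r * Hk heads via unsqueeze/expand/reshape gives each block of r consecutive query heads the same KV head, which matches the repeat-interleave grouping that enable_gqa uses internally. A standalone sanity check of that equivalence (illustrative sizes; assumes a PyTorch build where F.scaled_dot_product_attention accepts enable_gqa, i.e. 2.5+):

import torch
import torch.nn.functional as F

B, Hq, Hk, T, D = 1, 8, 2, 4, 16  # Hq must be a multiple of Hk
q = torch.randn(B, Hq, T, D)
k = torch.randn(B, Hk, T, D)
v = torch.randn(B, Hk, T, D)

# Tile each KV head r times so query head (h_kv * r + j) reads KV head h_kv,
# mirroring the expand/reshape in _sdpa_with_gqa.
r = Hq // Hk
k_tiled = k.unsqueeze(2).expand(B, Hk, r, T, D).reshape(B, Hq, T, D)
v_tiled = v.unsqueeze(2).expand(B, Hk, r, T, D).reshape(B, Hq, T, D)

out_tiled = F.scaled_dot_product_attention(q, k_tiled, v_tiled)
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
torch.testing.assert_close(out_tiled, out_gqa)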
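
Note on the math.sqrt(scale) rewrite of the aten.mul.Scalar nodes: per the removed comment, _scaled_dot_product_attention_math applies the scale to q and k before the matmul, so each side carries sqrt(scale) and the product of the two factors restores the full scale on the attention logits. A quick numeric check of that identity:

import math
import torch

q = torch.randn(3, 5)
k = torch.randn(4, 5)
scale = 0.25

s = math.sqrt(scale)
# (q * s) @ (k * s)^T == scale * (q @ k^T), which is why the decomposition
# can push sqrt(scale) into the two mul.Scalar nodes feeding the matmul.
torch.testing.assert_close((q * s) @ (k * s).T, scale * (q @ k.T))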