Arm backend: Add static cache integration test with llama #18404
First changed file in the diff (the Arm llama TOSA tests):

```diff
@@ -31,7 +31,11 @@
 from executorch.extension.llm.export.config.llm_config import LlmConfig
 
+from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM
+from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
 input_t = Tuple[torch.Tensor]
+input_th = Tuple[torch.Tensor, torch.Tensor]
 
 # Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
 this_files_dir = os.path.dirname(os.path.abspath(__file__))
```
```diff
@@ -41,6 +45,22 @@
 logger = logging.getLogger(__name__)
 
 
+class HFPositionalAdapter(torch.nn.Module):
+    def __init__(self, exportable):
+        super().__init__()
+        self.inner = exportable
+
+    def forward(self, input_ids, cache_position):
+        # HF StaticCache eager path requires int64 index tensors, but keeping
+        # cache_position as int32 during export capture avoids adding an extra
+        # int64->int32 cast node in the lowered graph.
+        if torch._dynamo.is_compiling():
+            cp = cache_position
+        else:
+            cp = cache_position.to(torch.long)
+        return self.inner(input_ids=input_ids, cache_position=cp)
+
+
 class TestLlama:
     """Test class of Llama models.
```
```diff
@@ -51,6 +71,44 @@ class TestLlama:
     """
 
+    def prepare_model_hf_static(self):
+        """
+        Build a tiny HF LLaMA wrapped with TorchExportableModuleForDecoderOnlyLM (StaticCache)
+        See https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L214C17-L214C53
+        """
+        # Tiny config
+        cfg = LlamaConfig(
+            vocab_size=32000,
+            hidden_size=256,
+            intermediate_size=512,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            use_cache=True,
+        )
+        base = LlamaForCausalLM(cfg).eval()
+
+        # REQUIRED: generation_config must request a 'static' cache with batch_size & max_cache_len
+        base.generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="static",
+            cache_config={"batch_size": 1, "max_cache_len": 128},
+        )
+
+        exportable = TorchExportableModuleForDecoderOnlyLM(
+            model=base, batch_size=1, max_cache_len=128
+        )
+
+        # Positional adapter so the pipeline can call module(*inputs)
+        model_for_pipeline = HFPositionalAdapter(exportable).eval()
+
+        # The tester will call model(*inputs). Provide (input_ids, cache_position)
+        input_ids = torch.tensor([[0]], dtype=torch.long)  # shape [1, 1]
+        cache_position = torch.tensor([0], dtype=torch.int32)  # shape [1]
+        inputs = (input_ids, cache_position)
+
+        return model_for_pipeline, inputs, None
+
     def prepare_model(self):
         checkpoint = None
         params_file = None
```
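As a side note, here is a minimal sketch, not part of this diff, of how the module and inputs returned by `prepare_model_hf_static` could be exercised outside the Arm test pipeline. The helper name `smoke_test_static_cache_export` and the direct `torch.export` call are assumptions; the actual test drives the model through `TosaPipelineINT` instead.

```python
import torch


def smoke_test_static_cache_export(model_for_pipeline, inputs):
    """Illustrative only: run the wrapped HF module eagerly, then capture it with torch.export."""
    with torch.no_grad():
        # Eager path: HFPositionalAdapter promotes the int32 cache_position to
        # int64 before calling into the HF StaticCache implementation.
        eager_out = model_for_pipeline(*inputs)

    # Export path: per the comment in HFPositionalAdapter, cache_position stays
    # int32 during capture, so no extra int64->int32 cast lands in the graph.
    exported_program = torch.export.export(model_for_pipeline, inputs)
    return eager_out, exported_program
```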
```diff
@@ -86,13 +144,18 @@ def prepare_model(self):
        # TODO: Enable key value cache
        args = [
            "--disable_dynamic_shape",
            "--max_seq_length",
            "4096",
            "--max_context_length",
            "4096",
            "-c",
            checkpoint,
            "-p",
            params_file,
            "--model",
            model_name,
        ]

        parser = build_args_parser()
        args = parser.parse_args(args)
        llm_config = LlmConfig.from_args(args)
```
```diff
@@ -123,11 +186,10 @@ def test_llama_tosa_FP():
         aten_op=[],
         exir_op=[],
         custom_path="llama_tosa_fb",
-        run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
+        run_on_tosa_ref_model=True,  # Just want to write TOSA FB to disk
         use_to_edge_transform_and_lower=True,
         transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
     )
-    pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
     pipeline.run()
```
```diff
@@ -144,12 +206,36 @@ def test_llama_tosa_INT():
         aten_op=[],
         exir_op=[],
         custom_path="llama_tosa_fb_int",
-        run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
+        run_on_tosa_ref_model=True,  # Just want to write TOSA FB to disk
         use_to_edge_transform_and_lower=True,
         frobenius_threshold=None,
         cosine_threshold=None,
     )
-    pipeline.add_stage_after("to_executorch", pipeline.tester.serialize)
     pipeline.run()
 
 
+def test_llama_tosa_INT_static():
+    llama_model, llama_inputs, _ = TestLlama().prepare_model_hf_static()
+    if llama_model is None or llama_inputs is None:
+        pytest.skip("Missing model and/or input files")
+
+    with torch.no_grad():
+        pipeline = TosaPipelineINT[input_th](
+            llama_model,
+            llama_inputs,
+            aten_op=[],
+            exir_op=[],
+            custom_path="llama_tosa_hf_static_int",
+            run_on_tosa_ref_model=True,
+            use_to_edge_transform_and_lower=True,
+            fold_quantize=False,
+        )
+        # NOTE: HF StaticCache INT currently keeps two delegated subgraphs
+        # after partitioning on this path, so expect two delegate calls in EXIR.
+        pipeline.change_args(
+            "check_count.exir",
+            {"torch.ops.higher_order.executorch_call_delegate": 2},
+        )
+        pipeline.run()
```
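To make the expectation behind the `check_count.exir` argument concrete, here is a rough sketch of what counting delegate calls amounts to. The helper `count_delegate_calls` is hypothetical and assumes an `EdgeProgramManager`-style object exposing `exported_program()`; the pipeline performs an equivalent check internally.

```python
def count_delegate_calls(edge_program_manager) -> int:
    """Illustrative only: count executorch_call_delegate nodes after partitioning."""
    graph_module = edge_program_manager.exported_program().graph_module
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function"
        and "executorch_call_delegate" in str(node.target)
    )
```

The new test above pins this count to 2 for the HF StaticCache INT path.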
Second changed file in the diff (a graph pass that decomposes SDPA nodes):

```diff
@@ -42,19 +42,100 @@ def call(
         graph_module.recompile()
         return PassResult(graph_module, True)
 
     @staticmethod
     def _extract_input_tensors(node: torch.fx.Node) -> tuple[object, ...]:
         def _extract_arg_value(arg):
             if isinstance(arg, torch.fx.Node):
                 if "val" not in arg.meta:
                     raise RuntimeError(
                         f"Missing meta['val'] for SDPA arg node: {arg.name}"
                     )
                 return arg.meta["val"]
             return arg
 
         return tuple(_extract_arg_value(arg) for arg in node.args)
 
     @staticmethod
     def _copy_decomposed_graph(
         graph: torch.fx.Graph,
         node: torch.fx.Node,
         decomposed_module: torch.fx.GraphModule,
         scale: object,
     ) -> None:
         name_to_input_tensor_map = {}
         for i, arg in enumerate(node.args):
             name_to_input_tensor_map[f"arg{i}_1"] = arg
 
         decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
         last_decomposed_node = None
         for decomposed_node in decomposed_module.graph.nodes:
             if decomposed_node.op == "placeholder":
                 decomposed_node_to_subgraph_node[decomposed_node] = (
                     name_to_input_tensor_map[decomposed_node.name]
                 )
 
             if decomposed_node.op == "output":
                 last_decomposed_node = decomposed_node.args[0]
 
         for decomposed_node in decomposed_module.graph.nodes:
             node.meta["nn_module_stack"] = decomposed_node.meta.get("nn_module_stack")
             if decomposed_node.op == "placeholder":
                 continue
 
             if decomposed_node.op == "output" and last_decomposed_node is not None:
                 for user in node.users.copy():
                     user.replace_input_with(
                         node,
                         decomposed_node_to_subgraph_node[last_decomposed_node],
                     )
                 continue
 
             if scale is not None and decomposed_node.target in [
                 torch.ops.aten.mul.Scalar
             ]:
                 new_args = list(decomposed_node.args)
                 new_args[1] = math.sqrt(scale)
                 decomposed_node.args = tuple(new_args)
 
             subgraph_node = graph.node_copy(
                 decomposed_node,
                 arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
             )
             subgraph_node.meta["source_fn_stack"] = [
                 (subgraph_node, subgraph_node.target)
             ]
             decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
 
     def _decompose_sdpa_node(
         self,
         graph_module: torch.fx.GraphModule,
         node: torch.fx.Node,
         allow_non_fake_inputs: bool,
     ) -> None:
         graph = graph_module.graph
-        input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
+        input_tensors = self._extract_input_tensors(node)
         scale = node.kwargs.get("scale", None)
 
         def _sdpa_with_gqa(*args, **kwargs):
             # args: (q, k, v, [attn_mask, dropout_p, is_causal, scale])
             q, k, v = args[:3]
             # Shapes: (B, H, T, D)
             Hq = q.shape[1]
             Hk = k.shape[1]
             if Hq != Hk:
                 # LLaMA-style GQA: tile K and V heads to match Q
                 assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
```
Suggested change:

```diff
-                assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
+                if Hq % Hk != 0:
+                    raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
```
Is this just a refactor?
The inline comment still says “Just want to write TOSA FB to disk”, but run_on_tosa_ref_model is now True (and the explicit serialize stage was removed). Either update the comment to match the new behavior (running the TOSA ref model) or set run_on_tosa_ref_model=False if the intent is still artifact-only.
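For background on the GQA tiling referenced in `_sdpa_with_gqa` above: the diff is truncated at the assert by the review thread, so the `repeat_interleave`-based expansion below is an assumption about what such tiling typically looks like, not necessarily what this PR does. The helper `tile_kv_heads` is hypothetical.

```python
import torch


def tile_kv_heads(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Illustrative only: expand K/V heads so SDPA sees matching head counts.

    Shapes are (B, H, T, D); with grouped-query attention Hq is a multiple of Hk.
    """
    Hq, Hk = q.shape[1], k.shape[1]
    if Hq != Hk:
        if Hq % Hk != 0:
            raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
        repeat = Hq // Hk
        # Each K/V head is shared by `repeat` consecutive query heads.
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
    return k, v


# Example: 4 query heads over 2 KV heads, matching the tiny LlamaConfig above.
q = torch.randn(1, 4, 8, 64)
k = torch.randn(1, 2, 8, 64)
v = torch.randn(1, 2, 8, 64)
k, v = tile_kv_heads(q, k, v)
assert k.shape[1] == v.shape[1] == 4
```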