Arm backend: Add static cache integration test with llama #18404
xingguo01 wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18404
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 6 Pending, 3 Unrelated Failures as of commit c175435 with merge base e638059.
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @digantdesai, this PR touches code outside the Arm backend and needs a Meta review. Thanks!
Change-Id: I881fa107f43c9682c18480d01996a5795ae7f086
Signed-off-by: Xingguo Li <xingguo.li@arm.com>
de5d980 to c175435
Pull request overview
This PR adds an Arm backend integration test that exercises HuggingFace LLaMA StaticCache lowering, and adjusts backend transforms/passes to better support LLaMA-style attention and graph-signature constraints during lowering.
Changes:
- Extend SDPA decomposition to handle LLaMA-style GQA (Q heads != KV heads) and refactor SDPA graph-copying helpers.
- Add a new HuggingFace StaticCache-based LLaMA INT TOSA integration test.
- Fix bias placeholder insertion ordering for rewritten convs to satisfy constant-vs-user-input placeholder constraints.
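
On the last point, a minimal sketch of how such an ordering constraint can be satisfied in a torch.fx graph (the helper name, the `user_input_names` argument, and the omitted graph-signature bookkeeping are assumptions for illustration, not the actual pass):

```python
import torch


def insert_synthetic_bias(
    graph: torch.fx.Graph, bias_name: str, user_input_names: set[str]
) -> torch.fx.Node:
    """Insert a synthetic bias placeholder before the first user-input
    placeholder, so constant placeholders stay grouped ahead of user inputs.
    A real pass would also register the constant tensor and update the
    exported program's graph signature."""
    anchor = next(
        n
        for n in graph.nodes
        if n.op == "placeholder" and n.name in user_input_names
    )
    with graph.inserting_before(anchor):
        return graph.placeholder(bias_name)
```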
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| backends/transforms/decompose_sdpa.py | Refactors SDPA decomposition and adds a GQA wrapper for LLaMA-style head mismatch. |
| backends/arm/test/models/test_llama.py | Adds HF StaticCache LLaMA test and tweaks existing LLaMA TOSA pipeline settings. |
| backends/arm/_passes/rewrite_conv_pass.py | Ensures synthetic bias placeholders are inserted before user-input placeholders. |
        custom_path="llama_tosa_fb_int",
    -   run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
    +   run_on_tosa_ref_model=True,  # Just want to write TOSA FB to disk
        use_to_edge_transform_and_lower=True,
        frobenius_threshold=None,
        cosine_threshold=None,
The inline comment says this test “Just want to write TOSA FB to disk”, but run_on_tosa_ref_model is now True (and the explicit serialize stage was removed). Either update the comment to match the new behavior (running the TOSA ref model) or set run_on_tosa_ref_model=False if the intent is still artifact-only.
    @staticmethod
    def _extract_input_tensors(node: torch.fx.Node) -> tuple[object, ...]:
        def _extract_arg_value(arg):
            if isinstance(arg, torch.fx.Node):
                if "val" not in arg.meta:
                    raise RuntimeError(
                        f"Missing meta['val'] for SDPA arg node: {arg.name}"
                    )
                return arg.meta["val"]
            return arg

        return tuple(_extract_arg_value(arg) for arg in node.args)
_extract_input_tensors only walks node.args and ignores node.kwargs. For aten.scaled_dot_product_attention it’s common for attn_mask / dropout_p / is_causal / scale to be provided as kwargs, so the make_fx trace here can silently use defaults and decompose the wrong computation. Consider canonicalizing the SDPA call into a full positional arg list (q,k,v,attn_mask,dropout_p,is_causal,scale) by merging args+kwargs+defaults, and use that both for tracing and for the later scale adjustment (including handling scale passed positionally).
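
For illustration, a sketch of such a canonicalization (argument names and defaults are taken from the ATen `scaled_dot_product_attention` schema; `enable_gqa` is left out, and extraction of `meta["val"]` from any fx.Node values would still be applied afterwards):

```python
import torch

# Positional order and defaults of aten.scaled_dot_product_attention
# (enable_gqa omitted for brevity).
_SDPA_ARG_NAMES = ("query", "key", "value", "attn_mask", "dropout_p", "is_causal", "scale")
_SDPA_DEFAULTS = {"attn_mask": None, "dropout_p": 0.0, "is_causal": False, "scale": None}


def canonicalize_sdpa_args(node: torch.fx.Node) -> tuple:
    """Merge positional args, kwargs and schema defaults into one full
    positional tuple (query, key, value, attn_mask, dropout_p, is_causal, scale)."""
    merged: dict = dict(_SDPA_DEFAULTS)
    merged.update(zip(_SDPA_ARG_NAMES, node.args))  # positional arguments
    merged.update(node.kwargs)                      # keyword arguments
    return tuple(merged[name] for name in _SDPA_ARG_NAMES)
```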
    for decomposed_node in decomposed_module.graph.nodes:
        node.meta["nn_module_stack"] = decomposed_node.meta.get("nn_module_stack")
        if decomposed_node.op == "placeholder":
            continue
In _copy_decomposed_graph, nn_module_stack metadata is being written onto the original SDPA node (which is erased) instead of propagating from the original node to the decomposed nodes / copied subgraph nodes. This likely drops nn_module_stack on the new nodes and breaks downstream tooling relying on that metadata. The direction should match other decomposition passes (e.g., set decomposed_node.meta["nn_module_stack"] = node.meta.get("nn_module_stack") before node_copy, or set it on subgraph_node after copying).
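
A sketch of the propagation direction described above (the `graph`, `value_remap`, and copy-loop shape are assumptions about the surrounding pass, with `value_remap` assumed to be pre-seeded so placeholders map back to the original SDPA args):

```python
for decomposed_node in decomposed_module.graph.nodes:
    if decomposed_node.op == "placeholder":
        continue
    # Propagate metadata from the original SDPA node onto the node being
    # copied; node_copy copies meta, so the new node inherits it.
    decomposed_node.meta["nn_module_stack"] = node.meta.get("nn_module_stack")
    subgraph_node = graph.node_copy(decomposed_node, lambda n: value_remap[n])
    value_remap[decomposed_node] = subgraph_node
```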
    Hk = k.shape[1]
    if Hq != Hk:
        # LLaMA-style GQA: tile K and V heads to match Q
        assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
Using a bare assert for the GQA head-ratio check makes this validation disappear under Python -O and can turn a shape mismatch into harder-to-debug downstream errors during tracing. Prefer raising a RuntimeError / ValueError with the same message so it is always enforced.
Suggested change:

    -   assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
    +   if Hq % Hk != 0:
    +       raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
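
A minimal sketch of the head tiling the GQA wrapper performs, assuming a `[batch, heads, seq, head_dim]` layout with heads on dim 1 (the function name is illustrative):

```python
import torch


def tile_kv_heads(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Repeat K/V heads so they match the number of Q heads (LLaMA-style GQA)."""
    Hq, Hk = q.shape[1], k.shape[1]
    if Hq == Hk:
        return k, v
    if Hq % Hk != 0:
        raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
    n_rep = Hq // Hk
    # repeat_interleave keeps each KV head adjacent to its group of Q heads,
    # matching the grouping used by HF's repeat_kv helper.
    return k.repeat_interleave(n_rep, dim=1), v.repeat_interleave(n_rep, dim=1)
```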
        custom_path="llama_tosa_fb",
    -   run_on_tosa_ref_model=False,  # Just want to write TOSA FB to disk
    +   run_on_tosa_ref_model=True,  # Just want to write TOSA FB to disk
        use_to_edge_transform_and_lower=True,
        transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
The inline comment says this test “Just want to write TOSA FB to disk”, but run_on_tosa_ref_model is now True (and the explicit serialize stage was removed). Either update the comment to match the new behavior (running the TOSA ref model) or set run_on_tosa_ref_model=False if the intent is still artifact-only.
    )(*input_tensors)

    with graph.inserting_before(node):
        name_to_input_tensor_map = {}
Is this just a refactor?
digantdesai left a comment:
Does it work on the tosa ref model though? Just curious.
Add static cache integration tests in llama
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell