[OpenVINO] Draft support of Gemma3n #1650
Draft
popovaan wants to merge 10 commits into huggingface:main from
Conversation
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request · Apr 15, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such graphs, leaving the remaining SDPA blocks unfused. The resulting inconsistent model crashes at inference with `null input states` in `ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose KV-cache is shared with another SDPA, while still fusing the exclusive ones in the same model. Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA we walk forward from its `past_k` / `past_v` `ReadValue` via `ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and count how many `ScaledDotProductAttention` ops are reachable. If more than one SDPA is reachable from either the K-cache or the V-cache side, the callback returns `false` and that particular SDPA is left unfused; other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the direct K/V input nodes of each SDPA (`input_value(1).get_node()`, `input_value(2).get_node()`) and counts how many SDPAs reference the same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared source of the KV-cache and the SDPA. In the idealized shape it expects, the graph looks like:

```
ReadValue → Concat ──┬──→ SDPA1
                     │
                     └──→ SDPA2   ← the same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models), and depending on which version of transformers / optimum-intel was used to export the model, almost every path from the shared source to an SDPA carries some intermediate op: Transpose, Reshape, Convert, Gather, Broadcast, and so on.
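To make the failure mode concrete, here is a minimal self-contained sketch with toy structs (not the OpenVINO API; the node layout and helper names are illustrative only) showing why a pointer comparison of direct K inputs misses the sharing once each branch carries its own Transpose:

```cpp
#include <string>
#include <vector>

// Toy stand-in for a graph node; illustrative only, not ov::Node.
struct Node {
    std::string kind;
    std::vector<Node*> inputs;
};

// The direct-input check from #35260, reduced to its essence:
// "do these two SDPAs have the same K-input node pointer?"
bool same_direct_k_input(const Node& a, const Node& b) {
    return a.inputs[0] == b.inputs[0];
}

// ReadValue → Concat ──→ Transpose_A → SDPA1
//                   └──→ Transpose_B → SDPA2
// Both K paths originate from the same Concat, yet the direct inputs are
// two distinct Transpose objects, so the pointer comparison reports "not shared".
bool direct_check_detects_sharing() {
    Node rv{"ReadValue", {}};
    Node concat{"Concat", {&rv}};
    Node tr_a{"Transpose", {&concat}};
    Node tr_b{"Transpose", {&concat}};
    Node sdpa1{"SDPA", {&tr_a}};
    Node sdpa2{"SDPA", {&tr_b}};
    return same_direct_k_input(sdpa1, sdpa2);  // false: sharing missed
}
```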
The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is `Transpose_B`. Even if both Transposes have identical parameters, they are distinct `ov::Node` objects, so the "same direct input" comparison returns `false`. The sharing is no longer detected, `StatefulSDPAFusion` runs, partially fuses the graph, and the model crashes at runtime with `null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either: it depends on which earlier passes (TransposeSinking, SimplifyGatherShapeOf, shape-inference rewrites, etc.) have already run, and on small differences in how the model was exported. This makes the direct-input check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's direct K/V inputs, we walk the graph forward from the matched SDPA's `past_k` / `past_v` `ReadValue` and count how many SDPAs are reachable. The traversal passes through any non-SDPA op (Transpose, Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA boundaries, so intermediate topology does not hide the sharing.
2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion` globally as soon as the model contains a single shared-KV-cache SDPA is overly conservative: a model can mix SDPAs that share a cache with SDPAs that do not, and the exclusive ones would lose their fusion unnecessarily. The decision is made per SDPA inside the matcher callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a `skip_node_predicate` that returns `true` on `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited` set.
- If the count is greater than 1 for either `past_k` or `past_v`, the matched SDPA shares its cache with at least one other SDPA: return `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes, reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An SDPA on a different, unshared `ReadValue` in the same model is still fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the answer depends on the current shape of the graph. This PR asks "from this SDPA's KV-cache slot, can I reach another SDPA?", and that answer is independent of intermediate topology. The decision is taken per SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`): all 7 tests pass, including a new `StateConcatSDPAMixedSharedAndExclusive` that builds a model with one `ReadValue` feeding two SDPAs (shared) and another `ReadValue` feeding a single SDPA (exclusive), and asserts that after `SDPASubgraphFusion` exactly one `ScaledDotProductAttentionWithKVCache` is produced while the two shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650 at commit `698bfcec`, which predates the optimum-side workaround that replaces SDPA with matmul. Without this PR the graph still contains `ScaledDotProductAttention` ops and the crash reproduces; with this PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.
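The reachability idea behind the algorithm can be sketched self-contained with toy structs (not the actual `visit_path_forward` API; node layout and function names here are illustrative assumptions):

```cpp
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for a graph node; illustrative only, not ov::Node.
struct Node {
    std::string kind;
    std::vector<Node*> consumers;
};

// Forward BFS from a ReadValue, in the spirit of visit_path_forward:
// walk through any non-SDPA op, stop at SDPA boundaries, and count the
// distinct reachable SDPAs. A count > 1 means the cache is shared.
int count_reachable_sdpa(Node* read_value) {
    std::unordered_set<Node*> visited{read_value};
    std::queue<Node*> work;
    work.push(read_value);
    int sdpa_count = 0;
    while (!work.empty()) {
        Node* n = work.front();
        work.pop();
        if (n->kind == "SDPA") {
            ++sdpa_count;  // boundary: count it, do not walk past it
            continue;
        }
        for (Node* c : n->consumers)
            if (visited.insert(c).second) work.push(c);
    }
    return sdpa_count;
}

// ReadValue → Concat ──→ Transpose_A → SDPA1
//                   └──→ Transpose_B → SDPA2
int shared_cache_demo() {
    Node sdpa1{"SDPA", {}}, sdpa2{"SDPA", {}};
    Node tr_a{"Transpose", {&sdpa1}}, tr_b{"Transpose", {&sdpa2}};
    Node concat{"Concat", {&tr_a, &tr_b}};
    Node rv{"ReadValue", {&concat}};
    return count_reachable_sdpa(&rv);  // 2 → shared cache, skip the fusion
}

// ReadValue → Concat → Transpose → SDPA (exclusive cache)
int exclusive_cache_demo() {
    Node sdpa{"SDPA", {}};
    Node tr{"Transpose", {&sdpa}};
    Node concat{"Concat", {&tr}};
    Node rv{"ReadValue", {&concat}};
    return count_reachable_sdpa(&rv);  // 1 → exclusive, fuse as usual
}
```

Because the walk originates at the `ReadValue` and passes through arbitrary intermediate ops, adding or removing Transposes on either branch does not change the count, which is exactly the robustness the direct-input check lacks.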
### Tickets:

- 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request · Apr 16, 2026
github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request · Apr 16, 2026
github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request · Apr 16, 2026
github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request · Apr 16, 2026
praasz
pushed a commit
to praasz/openvino
that referenced
this pull request
Apr 20, 2026
sfatimar pushed a commit to sfatimar/openvino that referenced this pull request on Apr 23, 2026.
What does this PR do?
Fixes # (issue)
Before submitting