
[OpenVINO] Draft support of Gemma3n #1650

Draft
popovaan wants to merge 10 commits into huggingface:main from popovaan:gemma3n

Conversation

@popovaan
Collaborator

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 15, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking, SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
  all 7 tests pass, including a new `StateConcatSDPAMixedSharedAndExclusive` test
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
github-merge-queue Bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 16, 2026
praasz pushed a commit to praasz/openvino that referenced this pull request Apr 20, 2026
sfatimar pushed a commit to sfatimar/openvino that referenced this pull request Apr 23, 2026