
Commit 2bb459e

evkotov and mryzhov authored
Skip StatefulSDPAFusion for models with shared KV-cache (#35323)
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such graphs and leaves other SDPA blocks unfused. The resulting inconsistent model crashes at inference with `null input states` in `ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose KV-cache is shared with another SDPA, while still fusing the exclusive ones in the same model. Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA we walk forward from its `past_k` / `past_v` `ReadValue` via `ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and count how many `ScaledDotProductAttention` ops are reachable. If more than one SDPA is reachable from either the K-cache or the V-cache side, the callback returns `false` and this particular SDPA is left unfused; other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the direct K/V input nodes of each SDPA (`input_value(1).get_node()`, `input_value(2).get_node()`) and counts how many SDPAs reference the same node.

## Why the direct-input check is fragile

The direct-input check only works when there are no intermediate ops between the shared source of the KV-cache and the SDPA. In the idealized shape it expects, the graph looks like:

```
ReadValue → Concat ──→ SDPA1
              │
              └──→ SDPA2    ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and depending on which version of transformers / optimum-intel was used to export the model — almost every path from the shared source to an SDPA carries some intermediate op: Transpose, Reshape, Convert, Gather, Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is `Transpose_B`. Even if both Transposes have identical parameters, they are distinct `ov::Node` objects, so the "same direct input" comparison returns `false`. The sharing is no longer detected, `StatefulSDPAFusion` runs, partially fuses the graph, and the model crashes at runtime with `null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it depends on which earlier passes (TransposeSinking, SimplifyGatherShapeOf, shape-inference rewrites, etc.) have already run, and on small differences in how the model was exported. This makes the direct-input check behave differently across otherwise equivalent Gemma-style models; a minimal sketch of the comparison follows.
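To make the fragility concrete, here is a minimal sketch of that kind of direct-input comparison (illustrative only, not the literal code from #35260; the helper name `has_shared_direct_kv_input` is invented here). Sharing is only visible when two SDPAs hold the very same node pointer as a direct K or V input:

```cpp
#include <cstddef>
#include <memory>
#include <unordered_map>

#include "openvino/core/model.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

// Hypothetical helper: true if any node is the direct K or V input
// of more than one SDPA in the model.
static bool has_shared_direct_kv_input(const std::shared_ptr<ov::Model>& model) {
    std::unordered_map<ov::Node*, size_t> kv_input_refs;
    for (const auto& node : model->get_ordered_ops()) {
        if (ov::is_type<ov::op::v13::ScaledDotProductAttention>(node)) {
            ++kv_input_refs[node->input_value(1).get_node()];  // K input
            ++kv_input_refs[node->input_value(2).get_node()];  // V input
        }
    }
    for (const auto& entry : kv_input_refs) {
        if (entry.second > 1) {
            return true;  // two SDPAs hold the very same node pointer
        }
    }
    // Two identical-but-distinct Transposes produce two distinct map keys,
    // so the sharing behind them is invisible here.
    return false;
}
```

On the `Transpose_A` / `Transpose_B` graph above, every count is 1 and the sharing goes undetected.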
## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's direct K/V inputs, we walk the graph forward from the matched SDPA's `past_k` / `past_v` `ReadValue` and count how many SDPAs are reachable. The traversal passes through any non-SDPA op (Transpose, Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA boundaries, so intermediate topology does not hide the sharing.
2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion` globally as soon as the model contains a single shared-KV-cache SDPA is overly conservative: a model can mix SDPAs that share a cache with SDPAs that do not, and the exclusive ones would lose their fusion unnecessarily. The decision is made per SDPA inside the matcher callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a `skip_node_predicate` that returns `true` on `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited` set.
- If the count is greater than 1 for either `past_k` or `past_v`, the matched SDPA shares its cache with at least one other SDPA — return `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes, reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An SDPA on a different, unshared `ReadValue` in the same model is still fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the answer depends on the current shape of the graph. This PR asks "from this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is independent of intermediate topology, and the decision is taken per SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`): all 7 tests pass, including a new `StateConcatSDPAMixedSharedAndExclusive` that builds a model with one `ReadValue` feeding two SDPAs (shared) and another `ReadValue` feeding a single SDPA (exclusive), and asserts that after `SDPASubgraphFusion` exactly one `ScaledDotProductAttentionWithKVCache` is produced while the two shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650 at commit `698bfcec`, which predates the optimum-side workaround that replaces SDPA with matmul. Without this PR the graph still contains `ScaledDotProductAttention` ops and the crash reproduces; with this PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
- 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
1 parent: ef4c44b

2 files changed: 197 additions & 0 deletions

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp

Lines changed: 27 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 #include <cstdint>
 #include <memory>
 #include <tuple>
+#include <unordered_set>
 #include <vector>
 
 #include "openvino/cc/pass/itt.hpp"
@@ -158,6 +159,32 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
         const auto sdp_node = ov::as_type_ptr<ov::op::v13::ScaledDotProductAttention>(root);
         const auto past_k_node = ov::as_type_ptr<ov::op::v6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
         const auto past_v_node = ov::as_type_ptr<ov::op::v6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
+        // Skip SDPAs whose KV-cache Variable is shared with another SDPA: the fused
+        // ScaledDotProductAttentionWithKVCache kernel does not support shared KV-cache and
+        // partial fusion leaves the model in an inconsistent state.
+        // Walk forward from past_k / past_v, stopping at SDPA boundaries, and count
+        // how many SDPAs are reachable. Anchoring on the ReadValue (rather than on
+        // direct input node pointers) is robust to intermediate Transpose/Reshape/
+        // Convert/Gather/Broadcast ops between the cache source and the SDPA blocks.
+        auto count_reachable_sdpas = [](ov::Node* start) {
+            size_t cnt = 0;
+            std::unordered_set<ov::Node*> visited;
+            ov::op::util::visit_path_forward(
+                start,
+                visited,
+                [](ov::Node*) {},
+                [&](ov::Node* n) {
+                    if (ov::is_type<ov::op::v13::ScaledDotProductAttention>(n)) {
+                        ++cnt;
+                        return true;
+                    }
+                    return false;
+                });
+            return cnt;
+        };
+        if (count_reachable_sdpas(past_k_node.get()) > 1 || count_reachable_sdpas(past_v_node.get()) > 1) {
+            return false;
+        }
         if (!check_valid_children_type(past_k_node) || !check_valid_children_type(past_v_node)) {
             return false;
         }
```
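For readers without `ov::op::util::visit_path_forward` at hand, here is a self-contained approximation of what the new check computes, written as a plain BFS over node users (the helper name `count_sdpas_from` is illustrative, not the PR's code):

```cpp
#include <cstddef>
#include <queue>
#include <unordered_set>

#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

// BFS forward from a ReadValue through data-flow users. Every non-SDPA op
// (Transpose, Reshape, Convert, Gather, Concat, Broadcast, ...) is transparent;
// traversal stops at each SDPA, which is counted once.
static size_t count_sdpas_from(ov::Node* start) {
    size_t count = 0;
    std::unordered_set<ov::Node*> visited{start};
    std::queue<ov::Node*> todo;
    todo.push(start);
    while (!todo.empty()) {
        ov::Node* node = todo.front();
        todo.pop();
        for (const auto& user : node->get_users()) {
            ov::Node* next = user.get();
            if (!visited.insert(next).second) {
                continue;  // already reached via another path
            }
            if (ov::is_type<ov::op::v13::ScaledDotProductAttention>(next)) {
                ++count;  // SDPA boundary: count it, do not walk past it
            } else {
                todo.push(next);
            }
        }
    }
    return count;
}
```

A result greater than 1 for either the `past_k` or `past_v` `ReadValue` is exactly the condition under which the matcher callback above returns `false`.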

src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp

Lines changed: 170 additions & 0 deletions
```diff
@@ -19,9 +19,12 @@
 #include "transformations/utils/print_model.hpp"
 #include "openvino/op/abs.hpp"
 #include "openvino/op/add.hpp"
+#include "openvino/op/assign.hpp"
 #include "openvino/op/broadcast.hpp"
 #include "openvino/op/concat.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/read_value.hpp"
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/scaled_dot_product_attention.hpp"
 #include "openvino/op/shape_of.hpp"
@@ -237,3 +240,170 @@ TEST(TransformationTests, StateConcatSDPAWithExtraNode) {
     }
 }
+
+// Build a model with two SDPA blocks sharing the same KV-cache Variables.
+// One ReadValue per Variable fans out to two independent Gather -> Concat -> SDPA paths,
+// mirroring the shared-KV-cache pattern seen in Gemma3n/Gemma4 exports.
+static std::shared_ptr<ov::Model> makeSharedKVModel(const ov::PartialShape& inputShape) {
+    auto q1 = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto k1 = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto v1 = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto q2 = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto k2 = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto v2 = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto init_k = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto init_v = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
+    auto beam_idx = std::make_shared<ov::op::v0::Parameter>(element::i32, ov::PartialShape{-1});
+
+    // Shared variables for KV-cache (single ReadValue per Variable, fanning out).
+    auto var_k = std::make_shared<ov::op::util::Variable>(
+        ov::op::util::VariableInfo{inputShape, element::f32, "shared_pastk"});
+    auto var_v = std::make_shared<ov::op::util::Variable>(
+        ov::op::util::VariableInfo{inputShape, element::f32, "shared_pastv"});
+    auto pastk = std::make_shared<ov::op::v6::ReadValue>(init_k, var_k);
+    auto pastv = std::make_shared<ov::op::v6::ReadValue>(init_v, var_v);
+
+    // Path 1
+    auto gather_k1 = std::make_shared<ov::op::v8::Gather>(
+        pastk, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto gather_v1 = std::make_shared<ov::op::v8::Gather>(
+        pastv, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto concat_k1 = std::make_shared<ov::op::v0::Concat>(OutputVector{gather_k1, k1}, 2);
+    auto concat_v1 = std::make_shared<ov::op::v0::Concat>(OutputVector{gather_v1, v1}, 2);
+    auto sdpa1 = std::make_shared<ov::opset13::ScaledDotProductAttention>(q1, concat_k1, concat_v1, false);
+    auto add1 = std::make_shared<op::v1::Add>(sdpa1, op::v0::Constant::create(element::f32, {1}, {1.0f}));
+
+    // Path 2 (fans out from the same ReadValue as path 1)
+    auto gather_k2 = std::make_shared<ov::op::v8::Gather>(
+        pastk, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto gather_v2 = std::make_shared<ov::op::v8::Gather>(
+        pastv, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto concat_k2 = std::make_shared<ov::op::v0::Concat>(OutputVector{gather_k2, k2}, 2);
+    auto concat_v2 = std::make_shared<ov::op::v0::Concat>(OutputVector{gather_v2, v2}, 2);
+    auto sdpa2 = std::make_shared<ov::opset13::ScaledDotProductAttention>(q2, concat_k2, concat_v2, false);
+    auto add2 = std::make_shared<op::v1::Add>(sdpa2, op::v0::Constant::create(element::f32, {1}, {1.0f}));
+
+    // Single pair of Assigns for the shared Variables (from path 1's Concat).
+    auto assign_k = std::make_shared<op::v6::Assign>(concat_k1, var_k);
+    auto assign_v = std::make_shared<op::v6::Assign>(concat_v1, var_v);
+
+    ResultVector results{std::make_shared<ov::op::v0::Result>(add1),
+                         std::make_shared<ov::op::v0::Result>(add2)};
+    SinkVector sinks{assign_k, assign_v};
+    return std::make_shared<Model>(results, sinks,
+                                   ParameterVector{q1, k1, v1, q2, k2, v2, init_k, init_v, beam_idx}, "SharedKVModel");
+}
+
+TEST_F(TransformationTestsF, StateConcatSDPASharedKVCache) {
+#if defined(OPENVINO_ARCH_X86_64) && (defined(__ANDROID__) || defined(ANDROID))
+    test_skipped = true;
+    GTEST_SKIP() << "Skipping StateConcatSDPASharedKVCache test on Android X64";
+#endif
+    // When KV-cache is shared between multiple SDPA blocks, StatefulSDPAFusion must NOT apply:
+    // leaving model_ref unset lets the fixture compare against a clone of the input model.
+    auto inputShape = ov::PartialShape{-1, 8, -1, 64};
+    model = makeSharedKVModel(inputShape);
+    manager.register_pass<SDPASubgraphFusion>();
+}
+
+// Mix of SDPAs in one model:
+// - "shared" part: one pair of Variables feeds two SDPAs (SDPA_s1, SDPA_s2)
+// - "exclusive" part: another pair of Variables feeds a single SDPA (SDPA_e)
+// After SDPASubgraphFusion, only SDPA_e must be fused to ScaledDotProductAttentionWithKVCache;
+// SDPA_s1 and SDPA_s2 must remain as plain ScaledDotProductAttention.
+static std::shared_ptr<ov::Model> makeMixedSharedAndExclusiveKVModel(const ov::PartialShape& inputShape,
+                                                                     bool isRef = false) {
+    auto make_param = [&](element::Type t, const ov::PartialShape& s) {
+        return std::make_shared<ov::op::v0::Parameter>(t, s);
+    };
+    auto beam_idx = make_param(element::i32, ov::PartialShape{-1});
+
+    // Shared part
+    auto q_s1 = make_param(element::f32, inputShape);
+    auto k_s1 = make_param(element::f32, inputShape);
+    auto v_s1 = make_param(element::f32, inputShape);
+    auto q_s2 = make_param(element::f32, inputShape);
+    auto k_s2 = make_param(element::f32, inputShape);
+    auto v_s2 = make_param(element::f32, inputShape);
+    auto init_ks = make_param(element::f32, inputShape);
+    auto init_vs = make_param(element::f32, inputShape);
+    auto var_ks = std::make_shared<ov::op::util::Variable>(
+        ov::op::util::VariableInfo{inputShape, element::f32, "shared_pastk"});
+    auto var_vs = std::make_shared<ov::op::util::Variable>(
+        ov::op::util::VariableInfo{inputShape, element::f32, "shared_pastv"});
+    // Single ReadValue per shared Variable, fanning out to two SDPA paths.
+    auto rv_ks = std::make_shared<ov::op::v6::ReadValue>(init_ks, var_ks);
+    auto rv_vs = std::make_shared<ov::op::v6::ReadValue>(init_vs, var_vs);
+
+    auto g_ks1 = std::make_shared<ov::op::v8::Gather>(rv_ks, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto g_vs1 = std::make_shared<ov::op::v8::Gather>(rv_vs, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto c_ks1 = std::make_shared<ov::op::v0::Concat>(OutputVector{g_ks1, k_s1}, 2);
+    auto c_vs1 = std::make_shared<ov::op::v0::Concat>(OutputVector{g_vs1, v_s1}, 2);
+    auto sdpa_s1 = std::make_shared<ov::opset13::ScaledDotProductAttention>(q_s1, c_ks1, c_vs1, false);
+    auto add_s1 = std::make_shared<op::v1::Add>(sdpa_s1, op::v0::Constant::create(element::f32, {1}, {1.0f}));
+
+    auto g_ks2 = std::make_shared<ov::op::v8::Gather>(rv_ks, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto g_vs2 = std::make_shared<ov::op::v8::Gather>(rv_vs, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+    auto c_ks2 = std::make_shared<ov::op::v0::Concat>(OutputVector{g_ks2, k_s2}, 2);
+    auto c_vs2 = std::make_shared<ov::op::v0::Concat>(OutputVector{g_vs2, v_s2}, 2);
+    auto sdpa_s2 = std::make_shared<ov::opset13::ScaledDotProductAttention>(q_s2, c_ks2, c_vs2, false);
+    auto add_s2 = std::make_shared<op::v1::Add>(sdpa_s2, op::v0::Constant::create(element::f32, {1}, {1.0f}));
+
+    auto assign_ks = std::make_shared<op::v6::Assign>(c_ks1, var_ks);
+    auto assign_vs = std::make_shared<op::v6::Assign>(c_vs1, var_vs);
+
+    // Exclusive part (single SDPA on its own Variables — eligible for fusion)
+    auto q_e = make_param(element::f32, inputShape);
+    auto k_e = make_param(element::f32, inputShape);
+    auto v_e = make_param(element::f32, inputShape);
+    auto init_ke = make_param(element::f32, inputShape);
+    auto init_ve = make_param(element::f32, inputShape);
+    auto var_ke = std::make_shared<ov::op::util::Variable>(
+        ov::op::util::VariableInfo{inputShape, element::f32, "excl_pastk"});
+    auto var_ve = std::make_shared<ov::op::util::Variable>(
+        ov::op::util::VariableInfo{inputShape, element::f32, "excl_pastv"});
+    auto rv_ke = std::make_shared<ov::op::v6::ReadValue>(init_ke, var_ke);
+    auto rv_ve = std::make_shared<ov::op::v6::ReadValue>(init_ve, var_ve);
+    Output<ov::Node> sdp_e, concat_ke, concat_ve;
+    if (isRef) {
+        ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config;
+        config.fuse_concat = true;
+        auto fused = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(
+            OutputVector{q_e, k_e, v_e, beam_idx, rv_ke, rv_ve}, config);
+        sdp_e = fused->output(0);
+        concat_ke = fused->output(1);
+        concat_ve = fused->output(2);
+    } else {
+        auto g_ke =
+            std::make_shared<ov::op::v8::Gather>(rv_ke, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+        auto g_ve =
+            std::make_shared<ov::op::v8::Gather>(rv_ve, beam_idx, op::v0::Constant::create(element::i32, {1}, {0}));
+        concat_ke = std::make_shared<ov::op::v0::Concat>(OutputVector{g_ke, k_e}, 2);
+        concat_ve = std::make_shared<ov::op::v0::Concat>(OutputVector{g_ve, v_e}, 2);
+        sdp_e = std::make_shared<ov::opset13::ScaledDotProductAttention>(q_e, concat_ke, concat_ve, false);
+    }
+    auto add_e = std::make_shared<op::v1::Add>(sdp_e, op::v0::Constant::create(element::f32, {1}, {1.0f}));
+    auto assign_ke = std::make_shared<op::v6::Assign>(concat_ke, var_ke);
+    auto assign_ve = std::make_shared<op::v6::Assign>(concat_ve, var_ve);
+
+    ResultVector results{std::make_shared<ov::op::v0::Result>(add_s1),
+                         std::make_shared<ov::op::v0::Result>(add_s2),
+                         std::make_shared<ov::op::v0::Result>(add_e)};
+    SinkVector sinks{assign_ks, assign_vs, assign_ke, assign_ve};
+    ParameterVector params{q_s1, k_s1, v_s1, q_s2, k_s2, v_s2, init_ks, init_vs,
+                           q_e, k_e, v_e, init_ke, init_ve, beam_idx};
+    return std::make_shared<Model>(results, sinks, params, "MixedSharedExclusiveKV");
+}
+
+TEST_F(TransformationTestsF, StateConcatSDPAMixedSharedAndExclusive) {
+#if defined(OPENVINO_ARCH_X86_64) && (defined(__ANDROID__) || defined(ANDROID))
+    test_skipped = true;
+    GTEST_SKIP() << "Skipping StateConcatSDPAMixedSharedAndExclusive test on Android X64";
+#endif
+    // Exclusive-cache SDPA must still be fused; shared-cache SDPAs must be left alone.
+    auto inputShape = ov::PartialShape{-1, 8, -1, 64};
+    model = makeMixedSharedAndExclusiveKVModel(inputShape);
+    model_ref = makeMixedSharedAndExclusiveKVModel(inputShape, true);
+    manager.register_pass<SDPASubgraphFusion>();
+}
```
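The fixture verifies the mixed case by exact graph comparison against `model_ref`. As a hand-written equivalent (a sketch only, not part of the test file; it assumes the fusion pass has already run on `model`), the expected op counts are:

```cpp
// Sketch: after SDPASubgraphFusion on the mixed model, exactly one fused SDPA
// (the exclusive one) and two plain SDPAs (the shared-cache ones) should remain.
size_t fused = 0, plain = 0;
for (const auto& op : model->get_ordered_ops()) {
    if (ov::is_type<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(op)) {
        ++fused;
    } else if (ov::is_type<ov::op::v13::ScaledDotProductAttention>(op)) {
        ++plain;
    }
}
EXPECT_EQ(fused, 1);
EXPECT_EQ(plain, 2);
```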
