diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index 51ac026c7c5060..710dfa6d1c85f6 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -700,10 +700,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format if (reshape_mode == reshape::reshape_mode::base) { if (crop_axis == 0 && !crop_layout.get_partial_shape()[0].is_dynamic() && crop_layout.get_partial_shape()[0].get_length() == 1 && - !(user_info.second.get_partial_shape()[0].is_static() && - user_info.second.get_partial_shape()[0].get_length() == 1)) { + !reshape_desc->output_pattern.empty() && + reshape_desc->output_pattern[0] != 0 && reshape_desc->output_pattern[0] != 1) { // The crop produces exactly batch=1 per slice and the reshape squeezes that dim. - // The reshape absorbs that dim, so the padding axis in the output remains 0. + // output_pattern[0] == -1 means the batch dim is absorbed (squeezed). reshape_axis = 0; } else { auto mul = 1; @@ -767,7 +767,8 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format padding::DynamicDimsMask reshape_dyn_pad_mask; if (crop_axis == 0 && crop_dim_val == 1 && - !(reshape_ps[0].is_static() && reshape_ps[0].get_length() == 1)) { + !reshape_desc->output_pattern.empty() && + reshape_desc->output_pattern[0] != 0 && reshape_desc->output_pattern[0] != 1) { // The crop splits on the batch axis with exactly batch=1 per slice // and the reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x]. 
// Padding offsets are in units of one 4D batch slice (= f*y*x elements), diff --git a/src/plugins/intel_gpu/src/graph/include/reshape_inst.h b/src/plugins/intel_gpu/src/graph/include/reshape_inst.h index 36a5ffcf72960a..bb2e5978ea903e 100644 --- a/src/plugins/intel_gpu/src/graph/include/reshape_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/reshape_inst.h @@ -74,10 +74,13 @@ struct typed_program_node : public typed_program_node_base { return false; if (input_pshape[0].get_length() != 1) return false; - // Reject if the reshape just flattens spatial dims while keeping batch=1 - // (e.g. [1,C,H,W] -> [1,C,H*W]). Only allow when the batch dim is truly squeezed. - auto& out_ps = prim->output_partial_shape; - if (!out_ps[0].is_dynamic() && out_ps[0].get_length() == 1) + // Reject if the reshape preserves the batch=1 dim (spatial flatten, not batch squeeze). + // output_pattern[0] == -1 means the first dim is inferred (batch absorbed/squeezed). + // output_pattern[0] == 0 or 1 means batch=1 is explicitly kept. + // e.g. 
[1,C,H,W] -> [1,C,H*W] has pattern [1,-1] => reject + // [1,N,H,W,C] -> [N,H,W,C] has pattern [-1,H,W,C] => allow + auto first_out_pattern = prim->output_pattern[0]; + if (first_out_pattern == 0 || first_out_pattern == 1) return false; return true; } diff --git a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp index c48040eb8bc329..b6920e8a4841cf 100644 --- a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp @@ -1911,3 +1911,87 @@ TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_spatial_ for (size_t i = 0; i < slice_elems; i++) ASSERT_FLOAT_EQ(out1[i], input_data[1 * slice_elems + i]) << "Branch 1 mismatch at " << i; } + +// SwinTransformer layers.3 pattern: crop [1,1,H,W,C] → Reshape [-1,H,W,C] (base mode) +// Runtime reshape output[0] == 1, but output_pattern[0] == -1 (batch squeeze). +// In-place crop must still work correctly. 
+TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_reshape_window_count_one) {
+    auto& engine = get_test_engine();
+    tests::random_generator rg(GET_SUITE_NAME);
+
+    // 3-way QKV split on axis 0 with f=1 (like SwinTransformer window_count=1).
+    const size_t num_heads = 3, seq_len = 4, head_dim = 2;
+    // Elements per crop slice: one [1, heads, seq, dim] chunk.
+    const size_t slice_elems = 1 * num_heads * seq_len * head_dim;
+
+    // Dynamic f-dim input [3, -1, heads, seq, dim]; concrete runtime shape uses f=1.
+    auto in_layout = layout{ov::PartialShape{3, -1, static_cast<int64_t>(num_heads),
+                                             static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)},
+                            data_types::f32, format::bfzyx};
+    auto input_mem = engine.allocate_memory({{3, 1, static_cast<int64_t>(num_heads),
+                                              static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)},
+                                             data_types::f32, format::bfzyx});
+    auto axis_mem = engine.allocate_memory({{}, data_types::i64, format::bfyx});
+    auto splits_length_mem = engine.allocate_memory({{3}, data_types::i64, format::bfyx});
+
+    const int64_t axis = 0;
+
+    auto input_data = rg.generate_random_1d<float>(input_mem->count(), -1.f, 1.f);
+    set_values(input_mem, input_data);
+    set_values<int64_t>(axis_mem, {axis});
+    set_values<int64_t>(splits_length_mem, {1, 1, 1});
+
+    // output_pattern[0] == -1: batch dim is squeezed. Runtime output[0] resolves to 1.
+    const std::vector<int64_t> squeeze_pattern = {-1, static_cast<int64_t>(num_heads),
+                                                  static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)};
+    const ov::PartialShape squeeze_out_shape = {-1, static_cast<int64_t>(num_heads),
+                                                static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)};
+
+    cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;
+    topology topology(
+        input_layout("input", in_layout),
+        data("axis", axis_mem),
+        data("splits_length", splits_length_mem),
+        // Q branch
+        crop("crop_q", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 0, axis),
+        reshape("reshape_q", input_info("crop_q"), false, squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
+        reorder("output_q", input_info("reshape_q"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true),
+        // K branch
+        crop("crop_k", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 1, axis),
+        reshape("reshape_k", input_info("crop_k"), false, squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
+        reorder("output_k", input_info("reshape_k"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true),
+        // V branch
+        crop("crop_v", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 2, axis),
+        reshape("reshape_v", input_info("crop_v"), false, squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
+        reorder("output_v", input_info("reshape_v"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true)
+    );
+
+    auto config = get_test_default_config(engine);
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+    config.set_property(ov::intel_gpu::optimize_data(true));
+    network network(engine, topology, config);
+    network.set_input_data("input", input_mem);
+
+    auto outputs = network.execute();
+
+    // Crops must still be in-place optimized despite runtime reshape output[0] == 1
+    ASSERT_TRUE(network.get_primitive("crop_q")->can_be_optimized());
+    ASSERT_TRUE(network.get_primitive("crop_k")->can_be_optimized());
+    ASSERT_TRUE(network.get_primitive("crop_v")->can_be_optimized());
+
+    auto q_mem = outputs.at("output_q").get_memory();
+    cldnn::mem_lock<float> q_out(q_mem, get_test_stream());
+    auto k_mem = outputs.at("output_k").get_memory();
+    cldnn::mem_lock<float> k_out(k_mem, get_test_stream());
+    auto v_mem = outputs.at("output_v").get_memory();
+    cldnn::mem_lock<float> v_out(v_mem, get_test_stream());
+
+    ASSERT_EQ(q_out.size(), slice_elems);
+    ASSERT_EQ(k_out.size(), slice_elems);
+    ASSERT_EQ(v_out.size(), slice_elems);
+
+    // Each output must equal its own contiguous slice of the original input.
+    for (size_t i = 0; i < slice_elems; i++)
+        ASSERT_FLOAT_EQ(q_out[i], input_data[0 * slice_elems + i]) << "Q mismatch at " << i;
+    for (size_t i = 0; i < slice_elems; i++)
+        ASSERT_FLOAT_EQ(k_out[i], input_data[1 * slice_elems + i]) << "K mismatch at " << i;
+    for (size_t i = 0; i < slice_elems; i++)
+        ASSERT_FLOAT_EQ(v_out[i], input_data[2 * slice_elems + i]) << "V mismatch at " << i;
+}