Skip to content

Commit 3687ccc

Browse files
26.1.1 hotfix (#35170)
### Details: - *Cherry-pick a bug fix from the master branch* - *commit ID: 2f0404d* ### Tickets: - *CVS-184283* ### AI Assistance: - *AI assistance used: yes* - *If yes, summarize how AI was used and what human validation was performed (build/tests/manual checks).* Used the GitHub agent to cherry-pick the commit from the master branch Co-authored-by: Andrew Kwangwoong Park <andrew.park@intel.com> Co-authored-by: peterchen-intel <19401820+peterchen-intel@users.noreply.github.com>
1 parent 53fb74f commit 3687ccc

3 files changed

Lines changed: 96 additions & 8 deletions

File tree

src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -700,10 +700,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
700700
if (reshape_mode == reshape::reshape_mode::base) {
701701
if (crop_axis == 0 && !crop_layout.get_partial_shape()[0].is_dynamic() &&
702702
crop_layout.get_partial_shape()[0].get_length() == 1 &&
703-
!(user_info.second.get_partial_shape()[0].is_static() &&
704-
user_info.second.get_partial_shape()[0].get_length() == 1)) {
703+
!reshape_desc->output_pattern.empty() &&
704+
reshape_desc->output_pattern[0] != 0 && reshape_desc->output_pattern[0] != 1) {
705705
// The crop produces exactly batch=1 per slice and the reshape squeezes that dim.
706-
// The reshape absorbs that dim, so the padding axis in the output remains 0.
706+
// output_pattern[0] == -1 means the batch dim is absorbed (squeezed).
707707
reshape_axis = 0;
708708
} else {
709709
auto mul = 1;
@@ -767,7 +767,8 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
767767
padding::DynamicDimsMask reshape_dyn_pad_mask;
768768

769769
if (crop_axis == 0 && crop_dim_val == 1 &&
770-
!(reshape_ps[0].is_static() && reshape_ps[0].get_length() == 1)) {
770+
!reshape_desc->output_pattern.empty() &&
771+
reshape_desc->output_pattern[0] != 0 && reshape_desc->output_pattern[0] != 1) {
771772
// The crop splits on the batch axis with exactly batch=1 per slice
772773
// and the reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x].
773774
// Padding offsets are in units of one 4D batch slice (= f*y*x elements),

src/plugins/intel_gpu/src/graph/include/reshape_inst.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,13 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
7474
return false;
7575
if (input_pshape[0].get_length() != 1)
7676
return false;
77-
// Reject if the reshape just flattens spatial dims while keeping batch=1
78-
// (e.g. [1,C,H,W] -> [1,C,H*W]). Only allow when the batch dim is truly squeezed.
79-
auto& out_ps = prim->output_partial_shape;
80-
if (!out_ps[0].is_dynamic() && out_ps[0].get_length() == 1)
77+
// Reject if the reshape preserves the batch=1 dim (spatial flatten, not batch squeeze).
78+
// output_pattern[0] == -1 means the first dim is inferred (batch absorbed/squeezed).
79+
// output_pattern[0] == 0 or 1 means batch=1 is explicitly kept.
80+
// e.g. [1,C,H,W] -> [1,C,H*W] has pattern [1,-1] => reject
81+
// [1,N,H,W,C] -> [N,H,W,C] has pattern [-1,H,W,C] => allow
82+
auto first_out_pattern = prim->output_pattern[0];
83+
if (first_out_pattern == 0 || first_out_pattern == 1)
8184
return false;
8285
return true;
8386
}

src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,3 +1911,87 @@ TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_spatial_
19111911
for (size_t i = 0; i < slice_elems; i++)
19121912
ASSERT_FLOAT_EQ(out1[i], input_data[1 * slice_elems + i]) << "Branch 1 mismatch at " << i;
19131913
}
1914+
1915+
// SwinTransformer layers.3 pattern: crop [1,1,H,W,C] → Reshape [-1,H,W,C] (base mode)
1916+
// Runtime reshape output[0] == 1, but output_pattern[0] == -1 (batch squeeze).
1917+
// In-place crop must still work correctly.
1918+
TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_reshape_window_count_one) {
1919+
auto& engine = get_test_engine();
1920+
tests::random_generator rg(GET_SUITE_NAME);
1921+
1922+
// 3-way QKV split on axis 0 with f=1 (like SwinTransformer window_count=1)
1923+
const size_t num_heads = 3, seq_len = 4, head_dim = 2;
1924+
const size_t slice_elems = 1 * num_heads * seq_len * head_dim;
1925+
1926+
auto in_layout = layout{ov::PartialShape{3, -1, static_cast<int64_t>(num_heads),
1927+
static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)},
1928+
data_types::f32, format::bfzyx};
1929+
auto input_mem = engine.allocate_memory({{3, 1, static_cast<int64_t>(num_heads),
1930+
static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)},
1931+
data_types::f32, format::bfzyx});
1932+
auto axis_mem = engine.allocate_memory({{}, data_types::i64, format::bfyx});
1933+
auto splits_length_mem = engine.allocate_memory({{3}, data_types::i64, format::bfyx});
1934+
1935+
const int64_t axis = 0;
1936+
1937+
auto input_data = rg.generate_random_1d<float>(input_mem->count(), -1.f, 1.f);
1938+
set_values(input_mem, input_data);
1939+
set_values<int64_t>(axis_mem, {axis});
1940+
set_values<int64_t>(splits_length_mem, {1, 1, 1});
1941+
1942+
// output_pattern[0] == -1: batch dim is squeezed. Runtime output[0] resolves to 1.
1943+
const std::vector<int64_t> squeeze_pattern = {-1, static_cast<int64_t>(num_heads),
1944+
static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)};
1945+
const ov::PartialShape squeeze_out_shape = {-1, static_cast<int64_t>(num_heads),
1946+
static_cast<int64_t>(seq_len), static_cast<int64_t>(head_dim)};
1947+
1948+
cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;
1949+
topology topology(
1950+
input_layout("input", in_layout),
1951+
data("axis", axis_mem),
1952+
data("splits_length", splits_length_mem),
1953+
// Q branch
1954+
crop("crop_q", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 0, axis),
1955+
reshape("reshape_q", input_info("crop_q"), false, squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
1956+
reorder("output_q", input_info("reshape_q"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true),
1957+
// K branch
1958+
crop("crop_k", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 1, axis),
1959+
reshape("reshape_k", input_info("crop_k"), false, squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
1960+
reorder("output_k", input_info("reshape_k"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true),
1961+
// V branch
1962+
crop("crop_v", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 2, axis),
1963+
reshape("reshape_v", input_info("crop_v"), false, squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
1964+
reorder("output_v", input_info("reshape_v"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true)
1965+
);
1966+
1967+
auto config = get_test_default_config(engine);
1968+
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
1969+
config.set_property(ov::intel_gpu::optimize_data(true));
1970+
network network(engine, topology, config);
1971+
network.set_input_data("input", input_mem);
1972+
1973+
auto outputs = network.execute();
1974+
1975+
// Crops must still be in-place optimized despite runtime reshape output[0] == 1
1976+
ASSERT_TRUE(network.get_primitive("crop_q")->can_be_optimized());
1977+
ASSERT_TRUE(network.get_primitive("crop_k")->can_be_optimized());
1978+
ASSERT_TRUE(network.get_primitive("crop_v")->can_be_optimized());
1979+
1980+
auto q_mem = outputs.at("output_q").get_memory();
1981+
cldnn::mem_lock<float> q_out(q_mem, get_test_stream());
1982+
auto k_mem = outputs.at("output_k").get_memory();
1983+
cldnn::mem_lock<float> k_out(k_mem, get_test_stream());
1984+
auto v_mem = outputs.at("output_v").get_memory();
1985+
cldnn::mem_lock<float> v_out(v_mem, get_test_stream());
1986+
1987+
ASSERT_EQ(q_out.size(), slice_elems);
1988+
ASSERT_EQ(k_out.size(), slice_elems);
1989+
ASSERT_EQ(v_out.size(), slice_elems);
1990+
1991+
for (size_t i = 0; i < slice_elems; i++)
1992+
ASSERT_FLOAT_EQ(q_out[i], input_data[0 * slice_elems + i]) << "Q mismatch at " << i;
1993+
for (size_t i = 0; i < slice_elems; i++)
1994+
ASSERT_FLOAT_EQ(k_out[i], input_data[1 * slice_elems + i]) << "K mismatch at " << i;
1995+
for (size_t i = 0; i < slice_elems; i++)
1996+
ASSERT_FLOAT_EQ(v_out[i], input_data[2 * slice_elems + i]) << "V mismatch at " << i;
1997+
}

0 commit comments

Comments
 (0)