@@ -1911,3 +1911,87 @@ TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_spatial_
19111911 for (size_t i = 0 ; i < slice_elems; i++)
19121912 ASSERT_FLOAT_EQ (out1[i], input_data[1 * slice_elems + i]) << " Branch 1 mismatch at " << i;
19131913}
1914+
1915+ // SwinTransformer layers.3 pattern: crop [1,1,H,W,C] → Reshape [-1,H,W,C] (base mode)
1916+ // Runtime reshape output[0] == 1, but output_pattern[0] == -1 (batch squeeze).
1917+ // In-place crop must still work correctly.
1918+ TEST (prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_reshape_window_count_one) {
1919+ auto & engine = get_test_engine ();
1920+ tests::random_generator rg (GET_SUITE_NAME);
1921+
1922+ // 3-way QKV split on axis 0 with f=1 (like SwinTransformer window_count=1)
1923+ const size_t num_heads = 3 , seq_len = 4 , head_dim = 2 ;
1924+ const size_t slice_elems = 1 * num_heads * seq_len * head_dim;
1925+
1926+ auto in_layout = layout{ov::PartialShape{3 , -1 , static_cast <int64_t >(num_heads),
1927+ static_cast <int64_t >(seq_len), static_cast <int64_t >(head_dim)},
1928+ data_types::f32 , format::bfzyx};
1929+ auto input_mem = engine.allocate_memory ({{3 , 1 , static_cast <int64_t >(num_heads),
1930+ static_cast <int64_t >(seq_len), static_cast <int64_t >(head_dim)},
1931+ data_types::f32 , format::bfzyx});
1932+ auto axis_mem = engine.allocate_memory ({{}, data_types::i64 , format::bfyx});
1933+ auto splits_length_mem = engine.allocate_memory ({{3 }, data_types::i64 , format::bfyx});
1934+
1935+ const int64_t axis = 0 ;
1936+
1937+ auto input_data = rg.generate_random_1d <float >(input_mem->count (), -1 .f , 1 .f );
1938+ set_values (input_mem, input_data);
1939+ set_values<int64_t >(axis_mem, {axis});
1940+ set_values<int64_t >(splits_length_mem, {1 , 1 , 1 });
1941+
1942+ // output_pattern[0] == -1: batch dim is squeezed. Runtime output[0] resolves to 1.
1943+ const std::vector<int64_t > squeeze_pattern = {-1 , static_cast <int64_t >(num_heads),
1944+ static_cast <int64_t >(seq_len), static_cast <int64_t >(head_dim)};
1945+ const ov::PartialShape squeeze_out_shape = {-1 , static_cast <int64_t >(num_heads),
1946+ static_cast <int64_t >(seq_len), static_cast <int64_t >(head_dim)};
1947+
1948+ cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;
1949+ topology topology (
1950+ input_layout (" input" , in_layout),
1951+ data (" axis" , axis_mem),
1952+ data (" splits_length" , splits_length_mem),
1953+ // Q branch
1954+ crop (" crop_q" , {input_info (" input" ), input_info (" axis" ), input_info (" splits_length" )}, cldnn::tensor (1 ), cldnn::tensor (0 ), op_mode, 0 , axis),
1955+ reshape (" reshape_q" , input_info (" crop_q" ), false , squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
1956+ reorder (" output_q" , input_info (" reshape_q" ), format::bfyx, data_types::f32 , std::vector<float >(), reorder_mean_mode::subtract, padding (), true ),
1957+ // K branch
1958+ crop (" crop_k" , {input_info (" input" ), input_info (" axis" ), input_info (" splits_length" )}, cldnn::tensor (1 ), cldnn::tensor (0 ), op_mode, 1 , axis),
1959+ reshape (" reshape_k" , input_info (" crop_k" ), false , squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
1960+ reorder (" output_k" , input_info (" reshape_k" ), format::bfyx, data_types::f32 , std::vector<float >(), reorder_mean_mode::subtract, padding (), true ),
1961+ // V branch
1962+ crop (" crop_v" , {input_info (" input" ), input_info (" axis" ), input_info (" splits_length" )}, cldnn::tensor (1 ), cldnn::tensor (0 ), op_mode, 2 , axis),
1963+ reshape (" reshape_v" , input_info (" crop_v" ), false , squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base),
1964+ reorder (" output_v" , input_info (" reshape_v" ), format::bfyx, data_types::f32 , std::vector<float >(), reorder_mean_mode::subtract, padding (), true )
1965+ );
1966+
1967+ auto config = get_test_default_config (engine);
1968+ config.set_property (ov::intel_gpu::allow_new_shape_infer (true ));
1969+ config.set_property (ov::intel_gpu::optimize_data (true ));
1970+ network network (engine, topology, config);
1971+ network.set_input_data (" input" , input_mem);
1972+
1973+ auto outputs = network.execute ();
1974+
1975+ // Crops must still be in-place optimized despite runtime reshape output[0] == 1
1976+ ASSERT_TRUE (network.get_primitive (" crop_q" )->can_be_optimized ());
1977+ ASSERT_TRUE (network.get_primitive (" crop_k" )->can_be_optimized ());
1978+ ASSERT_TRUE (network.get_primitive (" crop_v" )->can_be_optimized ());
1979+
1980+ auto q_mem = outputs.at (" output_q" ).get_memory ();
1981+ cldnn::mem_lock<float > q_out (q_mem, get_test_stream ());
1982+ auto k_mem = outputs.at (" output_k" ).get_memory ();
1983+ cldnn::mem_lock<float > k_out (k_mem, get_test_stream ());
1984+ auto v_mem = outputs.at (" output_v" ).get_memory ();
1985+ cldnn::mem_lock<float > v_out (v_mem, get_test_stream ());
1986+
1987+ ASSERT_EQ (q_out.size (), slice_elems);
1988+ ASSERT_EQ (k_out.size (), slice_elems);
1989+ ASSERT_EQ (v_out.size (), slice_elems);
1990+
1991+ for (size_t i = 0 ; i < slice_elems; i++)
1992+ ASSERT_FLOAT_EQ (q_out[i], input_data[0 * slice_elems + i]) << " Q mismatch at " << i;
1993+ for (size_t i = 0 ; i < slice_elems; i++)
1994+ ASSERT_FLOAT_EQ (k_out[i], input_data[1 * slice_elems + i]) << " K mismatch at " << i;
1995+ for (size_t i = 0 ; i < slice_elems; i++)
1996+ ASSERT_FLOAT_EQ (v_out[i], input_data[2 * slice_elems + i]) << " V mismatch at " << i;
1997+ }
0 commit comments