Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
if (reshape_mode == reshape::reshape_mode::base) {
if (crop_axis == 0 && !crop_layout.get_partial_shape()[0].is_dynamic() &&
crop_layout.get_partial_shape()[0].get_length() == 1 &&
!(user_info.second.get_partial_shape()[0].is_static() &&
user_info.second.get_partial_shape()[0].get_length() == 1)) {
!reshape_desc->output_pattern.empty() &&
reshape_desc->output_pattern[0] != 0 && reshape_desc->output_pattern[0] != 1) {
// The crop produces exactly batch=1 per slice and the reshape squeezes that dim.
// The reshape absorbs that dim, so the padding axis in the output remains 0.
// output_pattern[0] == -1 means the batch dim is absorbed (squeezed).
reshape_axis = 0;
} else {
auto mul = 1;
Expand Down Expand Up @@ -767,7 +767,8 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
padding::DynamicDimsMask reshape_dyn_pad_mask;

if (crop_axis == 0 && crop_dim_val == 1 &&
!(reshape_ps[0].is_static() && reshape_ps[0].get_length() == 1)) {
!reshape_desc->output_pattern.empty() &&
reshape_desc->output_pattern[0] != 0 && reshape_desc->output_pattern[0] != 1) {
// The crop splits on the batch axis with exactly batch=1 per slice
// and the reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x].
// Padding offsets are in units of one 4D batch slice (= f*y*x elements),
Expand Down
11 changes: 7 additions & 4 deletions src/plugins/intel_gpu/src/graph/include/reshape_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
return false;
if (input_pshape[0].get_length() != 1)
return false;
// Reject if the reshape just flattens spatial dims while keeping batch=1
// (e.g. [1,C,H,W] -> [1,C,H*W]). Only allow when the batch dim is truly squeezed.
auto& out_ps = prim->output_partial_shape;
if (!out_ps[0].is_dynamic() && out_ps[0].get_length() == 1)
// Reject if the reshape preserves the batch=1 dim (spatial flatten, not batch squeeze).
// output_pattern[0] == -1 means the first dim is inferred (batch absorbed/squeezed).
// output_pattern[0] == 0 or 1 means batch=1 is explicitly kept.
// e.g. [1,C,H,W] -> [1,C,H*W] has pattern [1,-1] => reject
// [1,N,H,W,C] -> [N,H,W,C] has pattern [-1,H,W,C] => allow
auto first_out_pattern = prim->output_pattern[0];
if (first_out_pattern == 0 || first_out_pattern == 1)
return false;
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1911,3 +1911,87 @@ TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_spatial_
for (size_t i = 0; i < slice_elems; i++)
ASSERT_FLOAT_EQ(out1[i], input_data[1 * slice_elems + i]) << "Branch 1 mismatch at " << i;
}

// SwinTransformer layers.3 pattern: crop [1,1,H,W,C] → Reshape [-1,H,W,C] (base mode)
// Runtime reshape output[0] == 1, but output_pattern[0] == -1 (batch squeeze).
// In-place crop must still work correctly.
TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_reshape_window_count_one) {
    auto& engine = get_test_engine();
    tests::random_generator rand_gen(GET_SUITE_NAME);

    // 3-way QKV split on axis 0 with f=1 (like SwinTransformer window_count=1).
    const size_t num_heads = 3, seq_len = 4, head_dim = 2;
    const size_t slice_elems = 1 * num_heads * seq_len * head_dim;

    const auto heads_i64 = static_cast<int64_t>(num_heads);
    const auto seq_i64 = static_cast<int64_t>(seq_len);
    const auto dim_i64 = static_cast<int64_t>(head_dim);

    auto in_layout = layout{ov::PartialShape{3, -1, heads_i64, seq_i64, dim_i64},
                            data_types::f32, format::bfzyx};
    auto input_mem = engine.allocate_memory({{3, 1, heads_i64, seq_i64, dim_i64},
                                             data_types::f32, format::bfzyx});
    auto axis_mem = engine.allocate_memory({{}, data_types::i64, format::bfyx});
    auto splits_length_mem = engine.allocate_memory({{3}, data_types::i64, format::bfyx});

    const int64_t axis = 0;

    auto input_data = rand_gen.generate_random_1d<float>(input_mem->count(), -1.f, 1.f);
    set_values(input_mem, input_data);
    set_values<int64_t>(axis_mem, {axis});
    set_values<int64_t>(splits_length_mem, {1, 1, 1});

    // output_pattern[0] == -1: batch dim is squeezed. Runtime output[0] resolves to 1.
    const std::vector<int64_t> squeeze_pattern = {-1, heads_i64, seq_i64, dim_i64};
    const ov::PartialShape squeeze_out_shape = {-1, heads_i64, seq_i64, dim_i64};

    const cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;

    // One crop -> reshape -> reorder chain per Q/K/V branch; only the split index differs.
    struct branch_desc { std::string suffix; std::string label; int split_idx; };
    const std::vector<branch_desc> branches = {{"q", "Q", 0}, {"k", "K", 1}, {"v", "V", 2}};

    topology topo;
    topo.add(input_layout("input", in_layout));
    topo.add(data("axis", axis_mem));
    topo.add(data("splits_length", splits_length_mem));
    for (const auto& br : branches) {
        topo.add(crop("crop_" + br.suffix,
                      {input_info("input"), input_info("axis"), input_info("splits_length")},
                      cldnn::tensor(1), cldnn::tensor(0), op_mode, br.split_idx, axis));
        topo.add(reshape("reshape_" + br.suffix, input_info("crop_" + br.suffix), false,
                         squeeze_pattern, squeeze_out_shape, cldnn::reshape::reshape_mode::base));
        topo.add(reorder("output_" + br.suffix, input_info("reshape_" + br.suffix), format::bfyx,
                         data_types::f32, std::vector<float>(), reorder_mean_mode::subtract,
                         padding(), true));
    }

    auto config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
    config.set_property(ov::intel_gpu::optimize_data(true));
    network net(engine, topo, config);
    net.set_input_data("input", input_mem);

    auto outputs = net.execute();

    // Crops must still be in-place optimized despite runtime reshape output[0] == 1.
    for (const auto& br : branches)
        ASSERT_TRUE(net.get_primitive("crop_" + br.suffix)->can_be_optimized());

    // Each branch must read back exactly its own batch slice of the input.
    for (size_t b = 0; b < branches.size(); ++b) {
        auto out_mem = outputs.at("output_" + branches[b].suffix).get_memory();
        cldnn::mem_lock<float> out_ptr(out_mem, get_test_stream());
        ASSERT_EQ(out_ptr.size(), slice_elems);
        for (size_t i = 0; i < slice_elems; i++)
            ASSERT_FLOAT_EQ(out_ptr[i], input_data[b * slice_elems + i])
                << branches[b].label << " mismatch at " << i;
    }
}
Loading