Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,17 @@ void ScaledAttnLayerGPUTest::SetUp() {
manager.run_passes(functionRefs);

auto it = std::find_if(inputShapes[1].second.begin(), inputShapes[1].second.end(), [&](const ov::Shape& shape){
return shape[0] >= 128 || shape[2] >= 384 || shape[3] >= 128;
if (shape.empty()) {
return false;
}

const auto rank = shape.size();
const auto seq_idx = rank >= 2 ? rank - 2 : 0;
const auto head_idx = rank - 1;
return shape[0] >= 128 || shape[seq_idx] >= 384 || shape[head_idx] >= 128;
});

bool has_diff_head_size = inputShapes[1].first.begin()[3] != inputShapes[2].first.begin()[3];
bool has_diff_head_size = inputShapes[1].first[-1] != inputShapes[2].first[-1];

bool has_long_seq = it != inputShapes[1].second.end();

Expand Down Expand Up @@ -684,6 +691,9 @@ const auto dynamic_shape_params_4D = testing::Combine(testing::Values(ov::elemen
testing::ValuesIn({disable_transpose, transpose_value}),
testing::Values(false));


// test 10

INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnDynamic4D_GPU,
ScaledAttnLayerGPUTest,
dynamic_shape_params_4D,
Expand Down
Loading