Skip to content

Commit ce76ff5

Browse files
committed
fix setup test
1 parent 3c8eed4 commit ce76ff5

1 file changed

Lines changed: 186 additions & 161 deletions

File tree

test/single_layer/test_poolinglayer.cpp

Lines changed: 186 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -513,13 +513,13 @@ struct PoolingTestParams {
513513
std::string description;
514514
};
515515

516-
class PoolingWithTBBTest
516+
class PoolingParametrizedTest
517517
: public BaseTestFixture,
518-
public ::testing::WithParamInterface<PoolingTestParams> {};
518+
public ::testing::WithParamInterface<
519+
std::tuple<PoolingTestParams, RuntimeOptions>> {};
519520

520-
TEST_P(PoolingWithTBBTest, test_pooling_with_tbb) {
521-
auto params = GetParam();
522-
auto tbb_options = setTBBOptions();
521+
TEST_P(PoolingParametrizedTest, test_pooling_with_different_backends) {
522+
auto [params, runtime_options] = GetParam();
523523

524524
PoolingLayer layer(params.pool_shape, params.strides, params.pads,
525525
params.dilations, params.ceil_mode, params.pooling_type);
@@ -537,166 +537,191 @@ TEST_P(PoolingWithTBBTest, test_pooling_with_tbb) {
537537
std::vector<Tensor> inputs{input};
538538
std::vector<Tensor> outputs{output};
539539

540-
layer.run(inputs, outputs, tbb_options);
540+
layer.run(inputs, outputs, runtime_options);
541541

542542
auto output_data = *outputs[0].as<float>();
543543
expectVectorsNear(output_data, params.expected_output, 1e-5f);
544544
}
545545

546546
INSTANTIATE_TEST_SUITE_P(
547-
PoolingTBBTests, PoolingWithTBBTest,
548-
::testing::Values(
549-
PoolingTestParams{BaseTestFixture::basic1DData(),
550-
BaseTestFixture::basic1DShape(),
551-
{3},
552-
{2},
553-
{0, 0, 0, 0},
554-
{1, 1},
555-
false,
556-
"average",
557-
BaseTestFixture::get1DAverageExpected(),
558-
"1D_Avg_Stride2_TBB"},
559-
PoolingTestParams{BaseTestFixture::basic1DData(),
560-
BaseTestFixture::basic1DShape(),
561-
{3},
562-
{1},
563-
{1, 1, 0, 0},
564-
{1, 1},
565-
false,
566-
"average",
567-
{8.5f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.5f},
568-
"1D_Avg_Stride1_Padding_TBB"},
569-
PoolingTestParams{BaseTestFixture::basic1DData(),
570-
BaseTestFixture::basic1DShape(),
571-
{3},
572-
{2},
573-
{0, 0, 0, 0},
574-
{1, 1},
575-
false,
576-
"max",
577-
{9.0f, 7.0f, 5.0f},
578-
"1D_Max_Stride2_TBB"},
579-
PoolingTestParams{BaseTestFixture::basic1DData(),
580-
BaseTestFixture::basic1DShape(),
581-
{3},
582-
{3},
583-
{0, 0, 0, 0},
584-
{1, 1},
585-
false,
586-
"average",
587-
{8.0f, 5.0f},
588-
"1D_Avg_Stride3_TBB"},
589-
PoolingTestParams{BaseTestFixture::ascending1DData(),
590-
BaseTestFixture::ascending1DShape(),
591-
{3},
592-
{2},
593-
{0, 0, 0, 0},
594-
{1, 1},
595-
false,
596-
"average",
597-
{2.0f, 4.0f, 6.0f, 8.0f},
598-
"1D_Ascending_Avg_Stride2_TBB"},
599-
PoolingTestParams{BaseTestFixture::mixed1DData(),
600-
BaseTestFixture::ascending1DShape(),
601-
{3},
602-
{2},
603-
{0, 0, 0, 0},
604-
{1, 1},
605-
false,
606-
"max",
607-
{0.0f, 4.0f, 4.0f, 3.0f},
608-
"1D_Mixed_Max_Stride2_TBB"},
609-
PoolingTestParams{BaseTestFixture::basic2DData4x4(),
610-
BaseTestFixture::basic2DShape4x4(),
611-
{2, 2},
612-
{1, 1},
613-
{0, 0, 0, 0},
614-
{1, 1},
615-
false,
616-
"average",
617-
BaseTestFixture::get2DAverageStride1Expected(),
618-
"2D_Avg_Stride1_TBB"},
619-
PoolingTestParams{BaseTestFixture::basic2DData4x4(),
620-
BaseTestFixture::basic2DShape4x4(),
621-
{2, 2},
622-
{2, 2},
623-
{0, 0, 0, 0},
624-
{1, 1},
625-
false,
626-
"average",
627-
{6.5f, 4.5f, 4.5f, 6.5f},
628-
"2D_Avg_Stride2_TBB"},
629-
PoolingTestParams{BaseTestFixture::basic2DData3x3(),
630-
BaseTestFixture::basic2DShape3x3(),
631-
{2, 2},
632-
{1, 1},
633-
{0, 0, 0, 0},
634-
{1, 1},
635-
false,
636-
"max",
637-
{9.0f, 8.0f, 5.0f, 4.0f},
638-
"2D_Max_Stride1_TBB"},
639-
PoolingTestParams{BaseTestFixture::small2DData2x2(),
640-
BaseTestFixture::small2DShape2x2(),
641-
{2, 2},
642-
{1, 1},
643-
{0, 0, 0, 0},
644-
{1, 1},
645-
false,
646-
"average",
647-
{2.5f},
648-
"2D_Small_Avg_Stride1_TBB"},
649-
PoolingTestParams{
650-
BaseTestFixture::small2DData2x2(),
651-
BaseTestFixture::small2DShape2x2(),
652-
{2, 2},
653-
{1, 1},
654-
{1, 1, 1, 1},
655-
{1, 1},
656-
false,
657-
"average",
658-
{1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f, 3.0f, 3.5f, 4.0f},
659-
"2D_Small_Avg_Padding_TBB"},
660-
PoolingTestParams{BaseTestFixture::medium2DData5x5(),
661-
BaseTestFixture::medium2DShape5x5(),
662-
{3, 3},
663-
{2, 2},
664-
{0, 0, 0, 0},
665-
{1, 1},
666-
false,
667-
"average",
668-
{7.0f, 9.0f, 17.0f, 19.0f},
669-
"2D_Medium_Avg_3x3_Stride2_TBB"},
670-
PoolingTestParams{BaseTestFixture::zero2DData3x3(),
671-
BaseTestFixture::zero2DShape3x3(),
672-
{2, 2},
673-
{1, 1},
674-
{0, 0, 0, 0},
675-
{1, 1},
676-
false,
677-
"max",
678-
{0.0f, 0.0f, 0.0f, 0.0f},
679-
"2D_Zero_Max_TBB"},
680-
PoolingTestParams{BaseTestFixture::constant2DData4x4(7.0f),
681-
BaseTestFixture::constant2DShape4x4(),
682-
{2, 2},
683-
{1, 1},
684-
{0, 0, 0, 0},
685-
{1, 1},
686-
false,
687-
"average",
688-
std::vector<float>(9, 7.0f),
689-
"2D_Constant_Avg_TBB"},
690-
PoolingTestParams{BaseTestFixture::basic2DData4x4(),
691-
BaseTestFixture::basic2DShape4x4(),
692-
{2, 2},
693-
{1, 1},
694-
{0, 0, 0, 0},
695-
{2, 2},
696-
false,
697-
"max",
698-
{9.0f, 8.0f, 8.0f, 9.0f},
699-
"2D_Max_Dilation2_TBB"}),
700-
[](const ::testing::TestParamInfo<PoolingTestParams>& info) {
701-
return info.param.description;
547+
PoolingTests, PoolingParametrizedTest,
548+
::testing::Combine(
549+
::testing::Values(
550+
PoolingTestParams{BaseTestFixture::basic1DData(),
551+
BaseTestFixture::basic1DShape(),
552+
{3},
553+
{2},
554+
{0, 0, 0, 0},
555+
{1, 1},
556+
false,
557+
"average",
558+
BaseTestFixture::get1DAverageExpected(),
559+
"1D_Avg_Stride2"},
560+
PoolingTestParams{BaseTestFixture::basic1DData(),
561+
BaseTestFixture::basic1DShape(),
562+
{3},
563+
{1},
564+
{1, 1, 0, 0},
565+
{1, 1},
566+
false,
567+
"average",
568+
{8.5f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.5f},
569+
"1D_Avg_Stride1_Padding"},
570+
PoolingTestParams{BaseTestFixture::basic1DData(),
571+
BaseTestFixture::basic1DShape(),
572+
{3},
573+
{2},
574+
{0, 0, 0, 0},
575+
{1, 1},
576+
false,
577+
"max",
578+
{9.0f, 7.0f, 5.0f},
579+
"1D_Max_Stride2"},
580+
PoolingTestParams{BaseTestFixture::basic1DData(),
581+
BaseTestFixture::basic1DShape(),
582+
{3},
583+
{3},
584+
{0, 0, 0, 0},
585+
{1, 1},
586+
false,
587+
"average",
588+
{8.0f, 5.0f},
589+
"1D_Avg_Stride3"},
590+
PoolingTestParams{BaseTestFixture::ascending1DData(),
591+
BaseTestFixture::ascending1DShape(),
592+
{3},
593+
{2},
594+
{0, 0, 0, 0},
595+
{1, 1},
596+
false,
597+
"average",
598+
{2.0f, 4.0f, 6.0f, 8.0f},
599+
"1D_Ascending_Avg_Stride2"},
600+
PoolingTestParams{BaseTestFixture::mixed1DData(),
601+
BaseTestFixture::ascending1DShape(),
602+
{3},
603+
{2},
604+
{0, 0, 0, 0},
605+
{1, 1},
606+
false,
607+
"max",
608+
{0.0f, 4.0f, 4.0f, 3.0f},
609+
"1D_Mixed_Max_Stride2"},
610+
PoolingTestParams{BaseTestFixture::basic2DData4x4(),
611+
BaseTestFixture::basic2DShape4x4(),
612+
{2, 2},
613+
{1, 1},
614+
{0, 0, 0, 0},
615+
{1, 1},
616+
false,
617+
"average",
618+
BaseTestFixture::get2DAverageStride1Expected(),
619+
"2D_Avg_Stride1"},
620+
PoolingTestParams{BaseTestFixture::basic2DData4x4(),
621+
BaseTestFixture::basic2DShape4x4(),
622+
{2, 2},
623+
{2, 2},
624+
{0, 0, 0, 0},
625+
{1, 1},
626+
false,
627+
"average",
628+
{6.5f, 4.5f, 4.5f, 6.5f},
629+
"2D_Avg_Stride2"},
630+
PoolingTestParams{BaseTestFixture::basic2DData3x3(),
631+
BaseTestFixture::basic2DShape3x3(),
632+
{2, 2},
633+
{1, 1},
634+
{0, 0, 0, 0},
635+
{1, 1},
636+
false,
637+
"max",
638+
{9.0f, 8.0f, 5.0f, 4.0f},
639+
"2D_Max_Stride1"},
640+
PoolingTestParams{BaseTestFixture::small2DData2x2(),
641+
BaseTestFixture::small2DShape2x2(),
642+
{2, 2},
643+
{1, 1},
644+
{0, 0, 0, 0},
645+
{1, 1},
646+
false,
647+
"average",
648+
{2.5f},
649+
"2D_Small_Avg_Stride1"},
650+
PoolingTestParams{
651+
BaseTestFixture::small2DData2x2(),
652+
BaseTestFixture::small2DShape2x2(),
653+
{2, 2},
654+
{1, 1},
655+
{1, 1, 1, 1},
656+
{1, 1},
657+
false,
658+
"average",
659+
{1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f, 3.0f, 3.5f, 4.0f},
660+
"2D_Small_Avg_Padding"},
661+
PoolingTestParams{BaseTestFixture::medium2DData5x5(),
662+
BaseTestFixture::medium2DShape5x5(),
663+
{3, 3},
664+
{2, 2},
665+
{0, 0, 0, 0},
666+
{1, 1},
667+
false,
668+
"average",
669+
{7.0f, 9.0f, 17.0f, 19.0f},
670+
"2D_Medium_Avg_3x3_Stride2"},
671+
PoolingTestParams{BaseTestFixture::zero2DData3x3(),
672+
BaseTestFixture::zero2DShape3x3(),
673+
{2, 2},
674+
{1, 1},
675+
{0, 0, 0, 0},
676+
{1, 1},
677+
false,
678+
"max",
679+
{0.0f, 0.0f, 0.0f, 0.0f},
680+
"2D_Zero_Max"},
681+
PoolingTestParams{BaseTestFixture::constant2DData4x4(7.0f),
682+
BaseTestFixture::constant2DShape4x4(),
683+
{2, 2},
684+
{1, 1},
685+
{0, 0, 0, 0},
686+
{1, 1},
687+
false,
688+
"average",
689+
std::vector<float>(9, 7.0f),
690+
"2D_Constant_Avg"},
691+
PoolingTestParams{BaseTestFixture::basic2DData4x4(),
692+
BaseTestFixture::basic2DShape4x4(),
693+
{2, 2},
694+
{1, 1},
695+
{0, 0, 0, 0},
696+
{2, 2},
697+
false,
698+
"max",
699+
{9.0f, 8.0f, 8.0f, 9.0f},
700+
"2D_Max_Dilation2"}),
701+
::testing::Values(BaseTestFixture::setTBBOptions(),
702+
BaseTestFixture::setOmpOptions(),
703+
BaseTestFixture::setSeqOptions(),
704+
BaseTestFixture::setSTLOptions(),
705+
BaseTestFixture::setKokkosOptions())),
706+
[](const ::testing::TestParamInfo<
707+
std::tuple<PoolingTestParams, RuntimeOptions>>& info) {
708+
const auto& params = std::get<0>(info.param);
709+
const auto& options = std::get<1>(info.param);
710+
711+
std::string name = params.description + "_";
712+
if (options.par_backend == ParBackend::kTbb) {
713+
name += "TBB";
714+
} else if (options.par_backend == ParBackend::kOmp) {
715+
name += "OMP";
716+
} else if (options.par_backend == ParBackend::kThreads) {
717+
name += "STL";
718+
} else if (options.par_backend == ParBackend::kKokkos) {
719+
name += "Kokkos";
720+
} else {
721+
name += "Seq";
722+
}
723+
724+
std::replace(name.begin(), name.end(), ' ', '_');
725+
std::replace(name.begin(), name.end(), '-', '_');
726+
return name;
702727
});

0 commit comments

Comments
 (0)