@@ -513,13 +513,13 @@ struct PoolingTestParams {
513513 std::string description;
514514};
515515
516- class PoolingWithTBBTest
516+ class PoolingParametrizedTest
517517 : public BaseTestFixture,
518- public ::testing::WithParamInterface<PoolingTestParams> {};
518+ public ::testing::WithParamInterface<
519+ std::tuple<PoolingTestParams, RuntimeOptions>> {};
519520
520- TEST_P (PoolingWithTBBTest, test_pooling_with_tbb) {
521- auto params = GetParam ();
522- auto tbb_options = setTBBOptions ();
521+ TEST_P (PoolingParametrizedTest, test_pooling_with_different_backends) {
522+ auto [params, runtime_options] = GetParam ();
523523
524524 PoolingLayer layer (params.pool_shape , params.strides , params.pads ,
525525 params.dilations , params.ceil_mode , params.pooling_type );
@@ -537,166 +537,191 @@ TEST_P(PoolingWithTBBTest, test_pooling_with_tbb) {
537537 std::vector<Tensor> inputs{input};
538538 std::vector<Tensor> outputs{output};
539539
540- layer.run (inputs, outputs, tbb_options );
540+ layer.run (inputs, outputs, runtime_options );
541541
542542 auto output_data = *outputs[0 ].as <float >();
543543 expectVectorsNear (output_data, params.expected_output , 1e-5f );
544544}
545545
546546INSTANTIATE_TEST_SUITE_P (
547- PoolingTBBTests, PoolingWithTBBTest,
548- ::testing::Values (
549- PoolingTestParams{BaseTestFixture::basic1DData (),
550- BaseTestFixture::basic1DShape (),
551- {3 },
552- {2 },
553- {0 , 0 , 0 , 0 },
554- {1 , 1 },
555- false ,
556- " average" ,
557- BaseTestFixture::get1DAverageExpected (),
558- " 1D_Avg_Stride2_TBB" },
559- PoolingTestParams{BaseTestFixture::basic1DData (),
560- BaseTestFixture::basic1DShape (),
561- {3 },
562- {1 },
563- {1 , 1 , 0 , 0 },
564- {1 , 1 },
565- false ,
566- " average" ,
567- {8 .5f , 8 .0f , 7 .0f , 6 .0f , 5 .0f , 4 .0f , 3 .0f , 2 .5f },
568- " 1D_Avg_Stride1_Padding_TBB" },
569- PoolingTestParams{BaseTestFixture::basic1DData (),
570- BaseTestFixture::basic1DShape (),
571- {3 },
572- {2 },
573- {0 , 0 , 0 , 0 },
574- {1 , 1 },
575- false ,
576- " max" ,
577- {9 .0f , 7 .0f , 5 .0f },
578- " 1D_Max_Stride2_TBB" },
579- PoolingTestParams{BaseTestFixture::basic1DData (),
580- BaseTestFixture::basic1DShape (),
581- {3 },
582- {3 },
583- {0 , 0 , 0 , 0 },
584- {1 , 1 },
585- false ,
586- " average" ,
587- {8 .0f , 5 .0f },
588- " 1D_Avg_Stride3_TBB" },
589- PoolingTestParams{BaseTestFixture::ascending1DData (),
590- BaseTestFixture::ascending1DShape (),
591- {3 },
592- {2 },
593- {0 , 0 , 0 , 0 },
594- {1 , 1 },
595- false ,
596- " average" ,
597- {2 .0f , 4 .0f , 6 .0f , 8 .0f },
598- " 1D_Ascending_Avg_Stride2_TBB" },
599- PoolingTestParams{BaseTestFixture::mixed1DData (),
600- BaseTestFixture::ascending1DShape (),
601- {3 },
602- {2 },
603- {0 , 0 , 0 , 0 },
604- {1 , 1 },
605- false ,
606- " max" ,
607- {0 .0f , 4 .0f , 4 .0f , 3 .0f },
608- " 1D_Mixed_Max_Stride2_TBB" },
609- PoolingTestParams{BaseTestFixture::basic2DData4x4 (),
610- BaseTestFixture::basic2DShape4x4 (),
611- {2 , 2 },
612- {1 , 1 },
613- {0 , 0 , 0 , 0 },
614- {1 , 1 },
615- false ,
616- " average" ,
617- BaseTestFixture::get2DAverageStride1Expected (),
618- " 2D_Avg_Stride1_TBB" },
619- PoolingTestParams{BaseTestFixture::basic2DData4x4 (),
620- BaseTestFixture::basic2DShape4x4 (),
621- {2 , 2 },
622- {2 , 2 },
623- {0 , 0 , 0 , 0 },
624- {1 , 1 },
625- false ,
626- " average" ,
627- {6 .5f , 4 .5f , 4 .5f , 6 .5f },
628- " 2D_Avg_Stride2_TBB" },
629- PoolingTestParams{BaseTestFixture::basic2DData3x3 (),
630- BaseTestFixture::basic2DShape3x3 (),
631- {2 , 2 },
632- {1 , 1 },
633- {0 , 0 , 0 , 0 },
634- {1 , 1 },
635- false ,
636- " max" ,
637- {9 .0f , 8 .0f , 5 .0f , 4 .0f },
638- " 2D_Max_Stride1_TBB" },
639- PoolingTestParams{BaseTestFixture::small2DData2x2 (),
640- BaseTestFixture::small2DShape2x2 (),
641- {2 , 2 },
642- {1 , 1 },
643- {0 , 0 , 0 , 0 },
644- {1 , 1 },
645- false ,
646- " average" ,
647- {2 .5f },
648- " 2D_Small_Avg_Stride1_TBB" },
649- PoolingTestParams{
650- BaseTestFixture::small2DData2x2 (),
651- BaseTestFixture::small2DShape2x2 (),
652- {2 , 2 },
653- {1 , 1 },
654- {1 , 1 , 1 , 1 },
655- {1 , 1 },
656- false ,
657- " average" ,
658- {1 .0f , 1 .5f , 2 .0f , 2 .0f , 2 .5f , 3 .0f , 3 .0f , 3 .5f , 4 .0f },
659- " 2D_Small_Avg_Padding_TBB" },
660- PoolingTestParams{BaseTestFixture::medium2DData5x5 (),
661- BaseTestFixture::medium2DShape5x5 (),
662- {3 , 3 },
663- {2 , 2 },
664- {0 , 0 , 0 , 0 },
665- {1 , 1 },
666- false ,
667- " average" ,
668- {7 .0f , 9 .0f , 17 .0f , 19 .0f },
669- " 2D_Medium_Avg_3x3_Stride2_TBB" },
670- PoolingTestParams{BaseTestFixture::zero2DData3x3 (),
671- BaseTestFixture::zero2DShape3x3 (),
672- {2 , 2 },
673- {1 , 1 },
674- {0 , 0 , 0 , 0 },
675- {1 , 1 },
676- false ,
677- " max" ,
678- {0 .0f , 0 .0f , 0 .0f , 0 .0f },
679- " 2D_Zero_Max_TBB" },
680- PoolingTestParams{BaseTestFixture::constant2DData4x4 (7 .0f ),
681- BaseTestFixture::constant2DShape4x4 (),
682- {2 , 2 },
683- {1 , 1 },
684- {0 , 0 , 0 , 0 },
685- {1 , 1 },
686- false ,
687- " average" ,
688- std::vector<float >(9 , 7 .0f ),
689- " 2D_Constant_Avg_TBB" },
690- PoolingTestParams{BaseTestFixture::basic2DData4x4 (),
691- BaseTestFixture::basic2DShape4x4 (),
692- {2 , 2 },
693- {1 , 1 },
694- {0 , 0 , 0 , 0 },
695- {2 , 2 },
696- false ,
697- " max" ,
698- {9 .0f , 8 .0f , 8 .0f , 9 .0f },
699- " 2D_Max_Dilation2_TBB" }),
700- [](const ::testing::TestParamInfo<PoolingTestParams>& info) {
701- return info.param .description ;
547+ PoolingTests, PoolingParametrizedTest,
548+ ::testing::Combine (
549+ ::testing::Values (
550+ PoolingTestParams{BaseTestFixture::basic1DData (),
551+ BaseTestFixture::basic1DShape (),
552+ {3 },
553+ {2 },
554+ {0 , 0 , 0 , 0 },
555+ {1 , 1 },
556+ false ,
557+ " average" ,
558+ BaseTestFixture::get1DAverageExpected (),
559+ " 1D_Avg_Stride2" },
560+ PoolingTestParams{BaseTestFixture::basic1DData (),
561+ BaseTestFixture::basic1DShape (),
562+ {3 },
563+ {1 },
564+ {1 , 1 , 0 , 0 },
565+ {1 , 1 },
566+ false ,
567+ " average" ,
568+ {8 .5f , 8 .0f , 7 .0f , 6 .0f , 5 .0f , 4 .0f , 3 .0f , 2 .5f },
569+ " 1D_Avg_Stride1_Padding" },
570+ PoolingTestParams{BaseTestFixture::basic1DData (),
571+ BaseTestFixture::basic1DShape (),
572+ {3 },
573+ {2 },
574+ {0 , 0 , 0 , 0 },
575+ {1 , 1 },
576+ false ,
577+ " max" ,
578+ {9 .0f , 7 .0f , 5 .0f },
579+ " 1D_Max_Stride2" },
580+ PoolingTestParams{BaseTestFixture::basic1DData (),
581+ BaseTestFixture::basic1DShape (),
582+ {3 },
583+ {3 },
584+ {0 , 0 , 0 , 0 },
585+ {1 , 1 },
586+ false ,
587+ " average" ,
588+ {8 .0f , 5 .0f },
589+ " 1D_Avg_Stride3" },
590+ PoolingTestParams{BaseTestFixture::ascending1DData (),
591+ BaseTestFixture::ascending1DShape (),
592+ {3 },
593+ {2 },
594+ {0 , 0 , 0 , 0 },
595+ {1 , 1 },
596+ false ,
597+ " average" ,
598+ {2 .0f , 4 .0f , 6 .0f , 8 .0f },
599+ " 1D_Ascending_Avg_Stride2" },
600+ PoolingTestParams{BaseTestFixture::mixed1DData (),
601+ BaseTestFixture::ascending1DShape (),
602+ {3 },
603+ {2 },
604+ {0 , 0 , 0 , 0 },
605+ {1 , 1 },
606+ false ,
607+ " max" ,
608+ {0 .0f , 4 .0f , 4 .0f , 3 .0f },
609+ " 1D_Mixed_Max_Stride2" },
610+ PoolingTestParams{BaseTestFixture::basic2DData4x4 (),
611+ BaseTestFixture::basic2DShape4x4 (),
612+ {2 , 2 },
613+ {1 , 1 },
614+ {0 , 0 , 0 , 0 },
615+ {1 , 1 },
616+ false ,
617+ " average" ,
618+ BaseTestFixture::get2DAverageStride1Expected (),
619+ " 2D_Avg_Stride1" },
620+ PoolingTestParams{BaseTestFixture::basic2DData4x4 (),
621+ BaseTestFixture::basic2DShape4x4 (),
622+ {2 , 2 },
623+ {2 , 2 },
624+ {0 , 0 , 0 , 0 },
625+ {1 , 1 },
626+ false ,
627+ " average" ,
628+ {6 .5f , 4 .5f , 4 .5f , 6 .5f },
629+ " 2D_Avg_Stride2" },
630+ PoolingTestParams{BaseTestFixture::basic2DData3x3 (),
631+ BaseTestFixture::basic2DShape3x3 (),
632+ {2 , 2 },
633+ {1 , 1 },
634+ {0 , 0 , 0 , 0 },
635+ {1 , 1 },
636+ false ,
637+ " max" ,
638+ {9 .0f , 8 .0f , 5 .0f , 4 .0f },
639+ " 2D_Max_Stride1" },
640+ PoolingTestParams{BaseTestFixture::small2DData2x2 (),
641+ BaseTestFixture::small2DShape2x2 (),
642+ {2 , 2 },
643+ {1 , 1 },
644+ {0 , 0 , 0 , 0 },
645+ {1 , 1 },
646+ false ,
647+ " average" ,
648+ {2 .5f },
649+ " 2D_Small_Avg_Stride1" },
650+ PoolingTestParams{
651+ BaseTestFixture::small2DData2x2 (),
652+ BaseTestFixture::small2DShape2x2 (),
653+ {2 , 2 },
654+ {1 , 1 },
655+ {1 , 1 , 1 , 1 },
656+ {1 , 1 },
657+ false ,
658+ " average" ,
659+ {1 .0f , 1 .5f , 2 .0f , 2 .0f , 2 .5f , 3 .0f , 3 .0f , 3 .5f , 4 .0f },
660+ " 2D_Small_Avg_Padding" },
661+ PoolingTestParams{BaseTestFixture::medium2DData5x5 (),
662+ BaseTestFixture::medium2DShape5x5 (),
663+ {3 , 3 },
664+ {2 , 2 },
665+ {0 , 0 , 0 , 0 },
666+ {1 , 1 },
667+ false ,
668+ " average" ,
669+ {7 .0f , 9 .0f , 17 .0f , 19 .0f },
670+ " 2D_Medium_Avg_3x3_Stride2" },
671+ PoolingTestParams{BaseTestFixture::zero2DData3x3 (),
672+ BaseTestFixture::zero2DShape3x3 (),
673+ {2 , 2 },
674+ {1 , 1 },
675+ {0 , 0 , 0 , 0 },
676+ {1 , 1 },
677+ false ,
678+ " max" ,
679+ {0 .0f , 0 .0f , 0 .0f , 0 .0f },
680+ " 2D_Zero_Max" },
681+ PoolingTestParams{BaseTestFixture::constant2DData4x4 (7 .0f ),
682+ BaseTestFixture::constant2DShape4x4 (),
683+ {2 , 2 },
684+ {1 , 1 },
685+ {0 , 0 , 0 , 0 },
686+ {1 , 1 },
687+ false ,
688+ " average" ,
689+ std::vector<float >(9 , 7 .0f ),
690+ " 2D_Constant_Avg" },
691+ PoolingTestParams{BaseTestFixture::basic2DData4x4 (),
692+ BaseTestFixture::basic2DShape4x4 (),
693+ {2 , 2 },
694+ {1 , 1 },
695+ {0 , 0 , 0 , 0 },
696+ {2 , 2 },
697+ false ,
698+ " max" ,
699+ {9 .0f , 8 .0f , 8 .0f , 9 .0f },
700+ " 2D_Max_Dilation2" }),
701+ ::testing::Values(BaseTestFixture::setTBBOptions(),
702+ BaseTestFixture::setOmpOptions(),
703+ BaseTestFixture::setSeqOptions(),
704+ BaseTestFixture::setSTLOptions(),
705+ BaseTestFixture::setKokkosOptions())),
706+ [](const ::testing::TestParamInfo<
707+ std::tuple<PoolingTestParams, RuntimeOptions>>& info) {
708+ const auto & params = std::get<0 >(info.param );
709+ const auto & options = std::get<1 >(info.param );
710+
711+ std::string name = params.description + " _" ;
712+ if (options.par_backend == ParBackend::kTbb ) {
713+ name += " TBB" ;
714+ } else if (options.par_backend == ParBackend::kOmp ) {
715+ name += " OMP" ;
716+ } else if (options.par_backend == ParBackend::kThreads ) {
717+ name += " STL" ;
718+ } else if (options.par_backend == ParBackend::kKokkos ) {
719+ name += " Kokkos" ;
720+ } else {
721+ name += " Seq" ;
722+ }
723+
724+ std::replace (name.begin (), name.end (), ' ' , ' _' );
725+ std::replace (name.begin (), name.end (), ' -' , ' _' );
726+ return name;
702727 });
0 commit comments