|
1 | 1 | #pragma once |
2 | 2 |
|
| 3 | +#include <benchmark/benchmark.h> |
3 | 4 | #include <gtest/gtest.h> |
4 | 5 | #include <omp.h> |
5 | 6 | #include <tbb/tick_count.h> |
6 | 7 |
|
7 | 8 | #include <chrono> |
8 | 9 | #include <cstddef> |
| 10 | +#include <cstdlib> |
9 | 11 | #include <functional> |
10 | 12 | #include <sstream> |
11 | 13 | #include <stdexcept> |
|
21 | 23 |
|
22 | 24 | namespace ppc::util { |
23 | 25 |
|
24 | | -double GetTimeMPI(); |
25 | | -int GetMPIRank(); |
| 26 | +namespace detail { |
| 27 | + |
| 28 | +inline bool ContainsFilterToken(std::string_view value, const char *filter_env) { |
| 29 | + if (filter_env == nullptr || std::string_view(filter_env).empty()) { |
| 30 | + return true; |
| 31 | + } |
| 32 | + return value.find(filter_env) != std::string_view::npos; |
| 33 | +} |
| 34 | + |
| 35 | +inline bool ShouldRunBenchmark(std::string_view test_name) { |
| 36 | + return ContainsFilterToken(test_name, std::getenv("PPC_PERF_IMPL_FILTER")) && |
| 37 | + ContainsFilterToken(test_name, std::getenv("PPC_PERF_CATEGORY_FILTER")); |
| 38 | +} |
| 39 | + |
| 40 | +inline void CheckPerfMode(ppc::performance::PerfResults::TypeOfRunning mode) { |
| 41 | + if (mode == ppc::performance::PerfResults::TypeOfRunning::kPipeline || |
| 42 | + mode == ppc::performance::PerfResults::TypeOfRunning::kTaskRun) { |
| 43 | + return; |
| 44 | + } |
| 45 | + std::stringstream err_msg; |
| 46 | + err_msg << '\n' << "The type of performance check for the task was not selected.\n"; |
| 47 | + throw std::runtime_error(err_msg.str().c_str()); |
| 48 | +} |
| 49 | + |
| 50 | +template <typename InType, typename OutType> |
| 51 | +void RunTaskPipeline(const ppc::task::TaskPtr<InType, OutType> &task) { |
| 52 | + task->Validation(); |
| 53 | + task->PreProcessing(); |
| 54 | + task->Run(); |
| 55 | + task->PostProcessing(); |
| 56 | +} |
| 57 | + |
| 58 | +template <typename InType, typename OutType> |
| 59 | +void RunTaskForBenchmark(const ppc::task::TaskPtr<InType, OutType> &task, |
| 60 | + ppc::performance::PerfResults::TypeOfRunning mode, benchmark::State &state) { |
| 61 | + task->GetStateOfTesting() = ppc::task::StateOfTesting::kPerf; |
| 62 | + if (mode == ppc::performance::PerfResults::TypeOfRunning::kPipeline) { |
| 63 | + SynchronizeMpiRanks(); |
| 64 | + state.ResumeTiming(); |
| 65 | + RunTaskPipeline(task); |
| 66 | + state.PauseTiming(); |
| 67 | + return; |
| 68 | + } |
| 69 | + |
| 70 | + task->Validation(); |
| 71 | + task->PreProcessing(); |
| 72 | + SynchronizeMpiRanks(); |
| 73 | + state.ResumeTiming(); |
| 74 | + task->Run(); |
| 75 | + state.PauseTiming(); |
| 76 | + task->PostProcessing(); |
| 77 | +} |
| 78 | + |
| 79 | +inline std::string MakeBenchmarkName(const std::string &test_name, ppc::performance::PerfResults::TypeOfRunning mode) { |
| 80 | + return test_name + "/" + ppc::performance::GetStringParamName(mode); |
| 81 | +} |
| 82 | + |
| 83 | +} // namespace detail |
26 | 84 |
|
27 | 85 | template <typename InType, typename OutType> |
28 | 86 | using PerfTestParam = std::tuple<std::function<ppc::task::TaskPtr<InType, OutType>(InType)>, std::string, |
@@ -80,31 +138,35 @@ class BaseRunPerfTests : public ::testing::TestWithParam<PerfTestParam<InType, O |
80 | 138 | // A single perf test body may execute several implementations; do not abort the enabled ones. |
81 | 139 | return; |
82 | 140 | } |
| 141 | + if (!detail::ShouldRunBenchmark(test_name)) { |
| 142 | + return; |
| 143 | + } |
| 144 | + detail::CheckPerfMode(mode); |
83 | 145 |
|
84 | 146 | const auto test_env_scope = ppc::util::test::MakePerTestEnvForCurrentGTest(test_name); |
85 | 147 |
|
86 | | - task_ = task_getter(GetTestInputData()); |
87 | | - ppc::performance::Perf perf(task_); |
88 | | - ppc::performance::PerfAttr perf_attr; |
| 148 | + const auto input_data = GetTestInputData(); |
| 149 | + task_ = task_getter(input_data); |
89 | 150 | SynchronizeMpiRanks(); |
90 | | - SetPerfAttributes(perf_attr); |
91 | | - |
92 | | - if (mode == ppc::performance::PerfResults::TypeOfRunning::kPipeline) { |
93 | | - perf.PipelineRun(perf_attr); |
94 | | - } else if (mode == ppc::performance::PerfResults::TypeOfRunning::kTaskRun) { |
95 | | - perf.TaskRun(perf_attr); |
96 | | - } else { |
97 | | - std::stringstream err_msg; |
98 | | - err_msg << '\n' << "The type of performance check for the task was not selected.\n"; |
99 | | - throw std::runtime_error(err_msg.str().c_str()); |
100 | | - } |
101 | | - |
102 | | - if (GetMPIRank() == 0) { |
103 | | - perf.PrintPerfStatistic(test_name); |
104 | | - } |
| 151 | + detail::RunTaskPipeline(task_); |
105 | 152 |
|
106 | 153 | OutType output_data = task_->GetOutput(); |
107 | 154 | ASSERT_TRUE(CheckTestOutputData(output_data)); |
| 155 | + |
| 156 | + const auto benchmark_name = detail::MakeBenchmarkName(test_name, mode); |
| 157 | + benchmark::RegisterBenchmark(benchmark_name, |
| 158 | + [task_getter, input_data, mode](benchmark::State &state) { |
| 159 | + for (auto _ : state) { |
| 160 | + state.PauseTiming(); |
| 161 | + auto task = task_getter(input_data); |
| 162 | + detail::RunTaskForBenchmark(task, mode, state); |
| 163 | + benchmark::DoNotOptimize(task->GetOutput()); |
| 164 | + state.ResumeTiming(); |
| 165 | + } |
| 166 | + }) |
| 167 | + ->UseRealTime() |
| 168 | + ->Unit(benchmark::kMillisecond) |
| 169 | + ->MinTime(0.01); |
108 | 170 | } |
109 | 171 |
|
110 | 172 | private: |
|
0 commit comments