|
2 | 2 | #ifdef BUILD_PYTORCH |
3 | 3 | #include "DeepPotPT.h" |
4 | 4 |
|
5 | | -#include <torch/csrc/jit/runtime/jit_exception.h> |
6 | 5 | #include <torch/csrc/autograd/profiler.h> |
| 6 | +#include <torch/csrc/jit/runtime/jit_exception.h> |
7 | 7 |
|
8 | 8 | #include <cstdint> |
9 | 9 |
|
@@ -88,30 +88,32 @@ void DeepPotPT::init(const std::string& model, |
88 | 88 | const char* env_profiler = std::getenv("DP_PROFILER"); |
89 | 89 | if (env_profiler && *env_profiler) { |
90 | 90 | using torch::profiler::impl::ActivityType; |
| 91 | + using torch::profiler::impl::ExperimentalConfig; |
91 | 92 | using torch::profiler::impl::ProfilerConfig; |
92 | 93 | using torch::profiler::impl::ProfilerState; |
93 | | - using torch::profiler::impl::ExperimentalConfig; |
94 | 94 | std::set<ActivityType> activities{ActivityType::CPU}; |
95 | | - if (gpu_enabled) activities.insert(ActivityType::CUDA); |
| 95 | + if (gpu_enabled) { |
| 96 | + activities.insert(ActivityType::CUDA); |
| 97 | + } |
96 | 98 | profiler_file = std::string(env_profiler); |
97 | 99 | if (gpu_enabled) { |
98 | 100 | profiler_file += "_gpu" + std::to_string(gpu_id); |
99 | 101 | } |
100 | 102 | profiler_file += ".json"; |
101 | 103 | ExperimentalConfig exp_cfg; |
102 | | - ProfilerConfig cfg( |
103 | | - ProfilerState::KINETO, |
104 | | - false, // report_input_shapes, |
105 | | - false, // profile_memory, |
106 | | - true, // with_stack, |
107 | | - false, // with_flops, |
108 | | - true, // with_modules, |
109 | | - exp_cfg, |
110 | | - std::string() // trace_id |
| 104 | + ProfilerConfig cfg(ProfilerState::KINETO, |
| 105 | + false, // report_input_shapes, |
| 106 | + false, // profile_memory, |
| 107 | + true, // with_stack, |
| 108 | + false, // with_flops, |
| 109 | + true, // with_modules, |
| 110 | + exp_cfg, |
| 111 | + std::string() // trace_id |
111 | 112 | ); |
112 | 113 | torch::autograd::profiler::prepareProfiler(cfg, activities); |
113 | 114 | torch::autograd::profiler::enableProfiler(cfg, activities); |
114 | | - std::cout << "PyTorch profiler enabled, output file: " << profiler_file << std::endl; |
| 115 | + std::cout << "PyTorch profiler enabled, output file: " << profiler_file |
| 116 | + << std::endl; |
115 | 117 | profiler_enabled = true; |
116 | 118 | } |
117 | 119 | std::unordered_map<std::string, std::string> metadata = {{"type", ""}}; |
@@ -151,8 +153,11 @@ void DeepPotPT::init(const std::string& model, |
151 | 153 | DeepPotPT::~DeepPotPT() { |
152 | 154 | if (profiler_enabled) { |
153 | 155 | auto result = torch::autograd::profiler::disableProfiler(); |
154 | | - if (result) result->save(profiler_file); |
155 | | - std::cout << "PyTorch profiler result saved to " << profiler_file << std::endl; |
| 156 | + if (result) { |
| 157 | + result->save(profiler_file); |
| 158 | + } |
| 159 | + std::cout << "PyTorch profiler result saved to " << profiler_file |
| 160 | + << std::endl; |
156 | 161 | } |
157 | 162 | } |
158 | 163 |
|
|
0 commit comments