Skip to content

Commit 25a36d1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1a2792b commit 25a36d1

1 file changed

Lines changed: 20 additions & 15 deletions

File tree

source/api_cc/src/DeepPotPT.cc

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#ifdef BUILD_PYTORCH
33
#include "DeepPotPT.h"
44

5-
#include <torch/csrc/jit/runtime/jit_exception.h>
65
#include <torch/csrc/autograd/profiler.h>
6+
#include <torch/csrc/jit/runtime/jit_exception.h>
77

88
#include <cstdint>
99

@@ -88,30 +88,32 @@ void DeepPotPT::init(const std::string& model,
8888
const char* env_profiler = std::getenv("DP_PROFILER");
8989
if (env_profiler && *env_profiler) {
9090
using torch::profiler::impl::ActivityType;
91+
using torch::profiler::impl::ExperimentalConfig;
9192
using torch::profiler::impl::ProfilerConfig;
9293
using torch::profiler::impl::ProfilerState;
93-
using torch::profiler::impl::ExperimentalConfig;
9494
std::set<ActivityType> activities{ActivityType::CPU};
95-
if (gpu_enabled) activities.insert(ActivityType::CUDA);
95+
if (gpu_enabled) {
96+
activities.insert(ActivityType::CUDA);
97+
}
9698
profiler_file = std::string(env_profiler);
9799
if (gpu_enabled) {
98100
profiler_file += "_gpu" + std::to_string(gpu_id);
99101
}
100102
profiler_file += ".json";
101103
ExperimentalConfig exp_cfg;
102-
ProfilerConfig cfg(
103-
ProfilerState::KINETO,
104-
false, // report_input_shapes,
105-
false, // profile_memory,
106-
true, // with_stack,
107-
false, // with_flops,
108-
true, // with_modules,
109-
exp_cfg,
110-
std::string() // trace_id
104+
ProfilerConfig cfg(ProfilerState::KINETO,
105+
false, // report_input_shapes,
106+
false, // profile_memory,
107+
true, // with_stack,
108+
false, // with_flops,
109+
true, // with_modules,
110+
exp_cfg,
111+
std::string() // trace_id
111112
);
112113
torch::autograd::profiler::prepareProfiler(cfg, activities);
113114
torch::autograd::profiler::enableProfiler(cfg, activities);
114-
std::cout << "PyTorch profiler enabled, output file: " << profiler_file << std::endl;
115+
std::cout << "PyTorch profiler enabled, output file: " << profiler_file
116+
<< std::endl;
115117
profiler_enabled = true;
116118
}
117119
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
@@ -151,8 +153,11 @@ void DeepPotPT::init(const std::string& model,
151153
DeepPotPT::~DeepPotPT() {
152154
if (profiler_enabled) {
153155
auto result = torch::autograd::profiler::disableProfiler();
154-
if (result) result->save(profiler_file);
155-
std::cout << "PyTorch profiler result saved to " << profiler_file << std::endl;
156+
if (result) {
157+
result->save(profiler_file);
158+
}
159+
std::cout << "PyTorch profiler result saved to " << profiler_file
160+
<< std::endl;
156161
}
157162
}
158163

0 commit comments

Comments
 (0)