Skip to content

Commit 737878f

Browse files
Copilotcaic99
andcommitted
Update PyTorch profiler to use new torch::profiler API instead of deprecated torch::autograd::profiler
Co-authored-by: caic99 <78061359+caic99@users.noreply.github.com>
1 parent eeb7c06 commit 737878f

4 files changed

Lines changed: 44 additions & 17 deletions

File tree

source/api_cc/include/DeepPotPT.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/script.h>
55
#include <torch/torch.h>
66
#ifdef BUILD_PYTORCH
7-
#include <torch/autograd/profiler.h>
7+
#include <torch/profiler.h>
88
#endif
99

1010
#include "DeepPot.h"
@@ -347,7 +347,7 @@ class DeepPotPT : public DeepPotBackend {
347347
bool profiler_enabled;
348348
std::string profiler_output_dir;
349349
#ifdef BUILD_PYTORCH
350-
std::unique_ptr<torch::autograd::profiler::RecordProfile> profiler;
350+
std::shared_ptr<torch::profiler::Result> profiler_result;
351351
#endif
352352
/**
353353
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.

source/api_cc/include/DeepSpinPT.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/script.h>
55
#include <torch/torch.h>
66
#ifdef BUILD_PYTORCH
7-
#include <torch/autograd/profiler.h>
7+
#include <torch/profiler.h>
88
#endif
99

1010
#include "DeepSpin.h"
@@ -269,7 +269,7 @@ class DeepSpinPT : public DeepSpinBackend {
269269
bool profiler_enabled;
270270
std::string profiler_output_dir;
271271
#ifdef BUILD_PYTORCH
272-
std::unique_ptr<torch::autograd::profiler::RecordProfile> profiler;
272+
std::shared_ptr<torch::profiler::Result> profiler_result;
273273
#endif
274274
/**
275275
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.

source/api_cc/src/DeepPotPT.cc

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,11 @@ void DeepPotPT::init(const std::string& model,
119119
std::system(mkdir_cmd.c_str());
120120

121121
std::cout << "PyTorch profiler enabled. Output directory: " << profiler_output_dir << std::endl;
122-
// Initialize profiler with default configuration
123-
profiler = std::make_unique<torch::autograd::profiler::RecordProfile>();
122+
// Start profiling using new API
123+
torch::profiler::profile({
124+
torch::profiler::ProfilerActivity::CPU,
125+
torch::profiler::ProfilerActivity::CUDA,
126+
}, true, true, false); // record_shapes, profile_memory, with_stack
124127
#else
125128
std::cerr << "Warning: PyTorch profiler requested but BUILD_PYTORCH not defined" << std::endl;
126129
#endif
@@ -137,12 +140,15 @@ void DeepPotPT::init(const std::string& model,
137140
}
138141
DeepPotPT::~DeepPotPT() {
139142
#ifdef BUILD_PYTORCH
140-
if (profiler_enabled && profiler) {
143+
if (profiler_enabled) {
141144
try {
142145
// Save profiler results to file
143146
std::string output_file = profiler_output_dir + "/pytorch_profiler_trace.json";
144-
profiler->save(output_file);
145-
std::cout << "PyTorch profiler results saved to: " << output_file << std::endl;
147+
profiler_result = torch::profiler::disableProfiler();
148+
if (profiler_result) {
149+
profiler_result->save(output_file);
150+
std::cout << "PyTorch profiler results saved to: " << output_file << std::endl;
151+
}
146152
} catch (const std::exception& e) {
147153
std::cerr << "Warning: Failed to save profiler results: " << e.what() << std::endl;
148154
}
@@ -264,7 +270,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
264270
.to(device);
265271
}
266272

267-
// Profiling is automatically active when RecordProfile is constructed
273+
// Start profiling if enabled
274+
#ifdef BUILD_PYTORCH
275+
if (profiler_enabled && profiler) {
276+
profiler->step();
277+
}
278+
#endif
268279

269280
c10::Dict<c10::IValue, c10::IValue> outputs =
270281
(do_message_passing)
@@ -416,7 +427,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
416427
bool do_atom_virial_tensor = atomic;
417428
inputs.push_back(do_atom_virial_tensor);
418429

419-
// Profiling is automatically active when RecordProfile is constructed
430+
// Start profiling if enabled
431+
#ifdef BUILD_PYTORCH
432+
if (profiler_enabled && profiler) {
433+
profiler->step();
434+
}
435+
#endif
420436

421437
c10::Dict<c10::IValue, c10::IValue> outputs =
422438
module.forward(inputs).toGenericDict();

source/api_cc/src/DeepSpinPT.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,11 @@ void DeepSpinPT::init(const std::string& model,
119119
std::system(mkdir_cmd.c_str());
120120

121121
std::cout << "PyTorch profiler enabled. Output directory: " << profiler_output_dir << std::endl;
122-
// Initialize profiler with default configuration
123-
profiler = std::make_unique<torch::autograd::profiler::RecordProfile>();
122+
// Start profiling using new API
123+
torch::profiler::profile({
124+
torch::profiler::ProfilerActivity::CPU,
125+
torch::profiler::ProfilerActivity::CUDA,
126+
}, true, true, false); // record_shapes, profile_memory, with_stack
124127
#else
125128
std::cerr << "Warning: PyTorch profiler requested but BUILD_PYTORCH not defined" << std::endl;
126129
#endif
@@ -137,12 +140,15 @@ void DeepSpinPT::init(const std::string& model,
137140
}
138141
DeepSpinPT::~DeepSpinPT() {
139142
#ifdef BUILD_PYTORCH
140-
if (profiler_enabled && profiler) {
143+
if (profiler_enabled) {
141144
try {
142145
// Save profiler results to file
143146
std::string output_file = profiler_output_dir + "/pytorch_profiler_trace.json";
144-
profiler->save(output_file);
145-
std::cout << "PyTorch profiler results saved to: " << output_file << std::endl;
147+
profiler_result = torch::profiler::disableProfiler();
148+
if (profiler_result) {
149+
profiler_result->save(output_file);
150+
std::cout << "PyTorch profiler results saved to: " << output_file << std::endl;
151+
}
146152
} catch (const std::exception& e) {
147153
std::cerr << "Warning: Failed to save profiler results: " << e.what() << std::endl;
148154
}
@@ -440,7 +446,12 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
440446
bool do_atom_virial_tensor = atomic;
441447
inputs.push_back(do_atom_virial_tensor);
442448

443-
// Profiling is automatically active when RecordProfile is constructed
449+
// Start profiling if enabled
450+
#ifdef BUILD_PYTORCH
451+
if (profiler_enabled && profiler) {
452+
profiler->step();
453+
}
454+
#endif
444455

445456
c10::Dict<c10::IValue, c10::IValue> outputs =
446457
module.forward(inputs).toGenericDict();

0 commit comments

Comments
 (0)