@@ -119,8 +119,11 @@ void DeepPotPT::init(const std::string& model,
119119 std::system (mkdir_cmd.c_str ());
120120
121121 std::cout << " PyTorch profiler enabled. Output directory: " << profiler_output_dir << std::endl;
122- // Initialize profiler with default configuration
123- profiler = std::make_unique<torch::autograd::profiler::RecordProfile>();
122+ // Start profiling using new API
123+ torch::profiler::profile ({
124+ torch::profiler::ProfilerActivity::CPU,
125+ torch::profiler::ProfilerActivity::CUDA,
126+ }, true , true , false ); // record_shapes, profile_memory, with_stack
124127#else
125128 std::cerr << " Warning: PyTorch profiler requested but BUILD_PYTORCH not defined" << std::endl;
126129#endif
@@ -137,12 +140,15 @@ void DeepPotPT::init(const std::string& model,
137140}
138141DeepPotPT::~DeepPotPT () {
139142#ifdef BUILD_PYTORCH
140- if (profiler_enabled && profiler ) {
143+ if (profiler_enabled) {
141144 try {
142145 // Save profiler results to file
143146 std::string output_file = profiler_output_dir + " /pytorch_profiler_trace.json" ;
144- profiler->save (output_file);
145- std::cout << " PyTorch profiler results saved to: " << output_file << std::endl;
147+ profiler_result = torch::profiler::disableProfiler ();
148+ if (profiler_result) {
149+ profiler_result->save (output_file);
150+ std::cout << " PyTorch profiler results saved to: " << output_file << std::endl;
151+ }
146152 } catch (const std::exception& e) {
147153 std::cerr << " Warning: Failed to save profiler results: " << e.what () << std::endl;
148154 }
@@ -264,7 +270,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
264270 .to (device);
265271 }
266272
267- // Profiling is automatically active when RecordProfile is constructed
273+ // Start profiling if enabled
274+ #ifdef BUILD_PYTORCH
275+ if (profiler_enabled && profiler) {
276+ profiler->step ();
277+ }
278+ #endif
268279
269280 c10::Dict<c10::IValue, c10::IValue> outputs =
270281 (do_message_passing)
@@ -416,7 +427,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
416427 bool do_atom_virial_tensor = atomic;
417428 inputs.push_back (do_atom_virial_tensor);
418429
419- // Profiling is automatically active when RecordProfile is constructed
430+ // Start profiling if enabled
431+ #ifdef BUILD_PYTORCH
432+ if (profiler_enabled && profiler) {
433+ profiler->step ();
434+ }
435+ #endif
420436
421437 c10::Dict<c10::IValue, c10::IValue> outputs =
422438 module .forward (inputs).toGenericDict ();
0 commit comments