
Commit 0ca67ae

caic99, pre-commit-ci[bot], and njzjz authored and committed
feat: add PyTorch profiler support to LAMMPS MD (deepmodeling#4969)
This pull request adds support for enabling the PyTorch profiler in the `DeepPotPT` backend via an environment variable, making it easier to profile and debug performance issues. The profiler is activated by setting the `DP_PROFILER` environment variable, and profiling results are saved to a JSON file when the object is destroyed. Fixes deepmodeling#4431

**Profiler integration:**

* Added a `profiler_enabled` flag and a `profiler_file` string to the `DeepPotPT` class to track profiler state and output location (`DeepPotPT.h`).
* In the `init` method, added logic to check for the `DP_PROFILER` environment variable. If set, the PyTorch profiler is configured and enabled, with output written to a file named according to the environment variable and GPU ID (`DeepPotPT.cc`).
* On destruction of `DeepPotPT`, if profiling was enabled, the profiler is disabled and results are saved to the specified file (`DeepPotPT.cc`).

## Summary by CodeRabbit

* **New Features**
  * Optional runtime profiling via the `DP_PROFILER` environment variable. Produces per-process JSON traces (CPU-only or CPU+GPU) and saves them automatically on shutdown; console messages report profiler status and the output filename.
* **Refactor**
  * Improved device selection: auto-detects CUDA, selects a GPU per process when available, and falls back to CPU; clearer startup diagnostics with no public API changes.
* **Documentation**
  * Added `DP_PROFILER` usage, filename rules, examples, and operational tips to the C++ interface docs.

---

Signed-off-by: Chun Cai <amoycaic@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <njzjz@qq.com>
1 parent 25ec5a6 commit 0ca67ae

3 files changed

Lines changed: 79 additions & 7 deletions


doc/env.md

Lines changed: 32 additions & 0 deletions
````diff
@@ -88,5 +88,37 @@ These environment variables also apply to third-party programs using the C++ int
 **Type**: List of paths, split by `:` on Unix and `;` on Windows
 
 List of customized OP plugin libraries to load, such as `/path/to/plugin1.so:/path/to/plugin2.so` on Linux and `/path/to/plugin1.dll;/path/to/plugin2.dll` on Windows.
+:::
+
+:::{envvar} DP_PROFILER
+
+{{ pytorch_icon }} Enable the built-in PyTorch Kineto profiler for the PyTorch C++ (inference) backend.
+
+**Type**: string (output file stem)
+
+**Default**: unset (disabled)
+
+When set to a non-empty value, profiling is enabled for the lifetime of the loaded PyTorch model (e.g. during LAMMPS runs). A JSON trace file is created on finish. The final file name is constructed as:
+
+- `<ENV_VALUE>_gpu<ID>.json` if running on GPU
+- `<ENV_VALUE>.json` if running on CPU
+
+The trace can be examined with the [Chrome trace viewer](https://ui.perfetto.dev/) (alternatively chrome://tracing). It includes:
+
+- CPU operator activities
+- CUDA activities (if available)
+
+Example:
+
+```bash
+export DP_PROFILER=result
+mpirun -np 4 lmp -in in.lammps
+# Produces result_gpuX.json, where X is the GPU id used by each MPI rank.
+```
+
+Tips:
+
+- Large runs can generate sizable JSON files; consider limiting the number of MD steps (e.g. to 20).
+- Currently this feature only supports a single process, or multi-process runs where each process uses a distinct GPU on the same node.
 
 :::
````

source/api_cc/include/DeepPotPT.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -340,6 +340,8 @@ class DeepPotPT : public DeepPotBackend {
   at::Tensor firstneigh_tensor;
   c10::optional<torch::Tensor> mapping_tensor;
   torch::Dict<std::string, torch::Tensor> comm_dict;
+  bool profiler_enabled{false};
+  std::string profiler_file;
   /**
    * @brief Translate PyTorch exceptions to the DeePMD-kit exception.
    * @param[in] f The function to run.
```

source/api_cc/src/DeepPotPT.cc

Lines changed: 45 additions & 7 deletions
```diff
@@ -2,6 +2,7 @@
 #ifdef BUILD_PYTORCH
 #include "DeepPotPT.h"
 
+#include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/jit/runtime/jit_exception.h>
 
 #include <cstdint>
@@ -69,13 +70,9 @@ void DeepPotPT::init(const std::string& model,
   }
   deepmd::load_op_library();
   int gpu_num = torch::cuda::device_count();
-  if (gpu_num > 0) {
-    gpu_id = gpu_rank % gpu_num;
-  } else {
-    gpu_id = 0;
-  }
-  torch::Device device(torch::kCUDA, gpu_id);
+  gpu_id = (gpu_num > 0) ? (gpu_rank % gpu_num) : 0;
   gpu_enabled = torch::cuda::is_available();
+  torch::Device device(torch::kCUDA, gpu_id);
   if (!gpu_enabled) {
     device = torch::Device(torch::kCPU);
     std::cout << "load model from: " << model << " to cpu " << std::endl;
@@ -86,6 +83,37 @@ void DeepPotPT::init(const std::string& model,
     std::cout << "load model from: " << model << " to gpu " << gpu_id
               << std::endl;
   }
+
+  // Configure PyTorch profiler
+  const char* env_profiler = std::getenv("DP_PROFILER");
+  if (env_profiler && *env_profiler) {
+    using torch::profiler::impl::ActivityType;
+    using torch::profiler::impl::ExperimentalConfig;
+    using torch::profiler::impl::ProfilerConfig;
+    using torch::profiler::impl::ProfilerState;
+    std::set<ActivityType> activities{ActivityType::CPU};
+    if (gpu_enabled) {
+      activities.insert(ActivityType::CUDA);
+    }
+    profiler_file = std::string(env_profiler);
+    if (gpu_enabled) {
+      profiler_file += "_gpu" + std::to_string(gpu_id);
+    }
+    profiler_file += ".json";
+    ExperimentalConfig exp_cfg;
+    ProfilerConfig cfg(ProfilerState::KINETO,
+                       false,  // report_input_shapes
+                       false,  // profile_memory
+                       true,   // with_stack
+                       false,  // with_flops
+                       true,   // with_modules
+                       exp_cfg);
+    torch::autograd::profiler::prepareProfiler(cfg, activities);
+    torch::autograd::profiler::enableProfiler(cfg, activities);
+    std::cout << "PyTorch profiler enabled, output file: " << profiler_file
+              << std::endl;
+    profiler_enabled = true;
+  }
   std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
   module = torch::jit::load(model, device, metadata);
   module.eval();
@@ -119,7 +147,17 @@ void DeepPotPT::init(const std::string& model,
   aparam_nall = module.run_method("is_aparam_nall").toBool();
   inited = true;
 }
-DeepPotPT::~DeepPotPT() {}
+
+DeepPotPT::~DeepPotPT() {
+  if (profiler_enabled) {
+    auto result = torch::autograd::profiler::disableProfiler();
+    if (result) {
+      result->save(profiler_file);
+    }
+    std::cout << "PyTorch profiler result saved to " << profiler_file
+              << std::endl;
+  }
+}
 
 template <typename VALUETYPE, typename ENERGYVTYPE>
 void DeepPotPT::compute(ENERGYVTYPE& ener,
```

0 commit comments
