Skip to content

Commit 101a901

Browse files
Copilotcaic99
andcommitted
Add MPI rank support for PyTorch profiler output files
Co-authored-by: caic99 <78061359+caic99@users.noreply.github.com>
1 parent a879955 commit 101a901

6 files changed

Lines changed: 65 additions & 4 deletions

File tree

PYTORCH_PROFILER_INTEGRATION.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ export DP_PYTORCH_PROFILER_OUTPUT_DIR=./profiler_results
1414

1515
3. Check for profiler output in the specified directory:
1616
```bash
17+
# For non-MPI usage (MPI not available or not initialized)
1718
ls -la ./profiler_results/pytorch_profiler_trace.json
19+
20+
# For MPI usage, each rank gets its own file
21+
ls -la ./profiler_results/pytorch_profiler_trace_rank*.json
1822
```
1923

2024
## Environment Variables
@@ -28,5 +32,13 @@ The profiler uses PyTorch's modern `torch::profiler` API and automatically:
2832
- Creates the output directory if it doesn't exist
2933
- Profiles all forward pass operations in DeepPotPT and DeepSpinPT
3034
- Saves profiling results to a JSON file when the object is destroyed
35+
- Automatically includes MPI rank in filename when MPI is available and initialized
36+
37+
## Output Files
38+
39+
- **Non-MPI usage** (MPI not available or not initialized): `pytorch_profiler_trace.json`
40+
- **MPI usage**: `pytorch_profiler_trace_rank{rank}.json` (e.g., `pytorch_profiler_trace_rank0.json`, `pytorch_profiler_trace_rank1.json`)
41+
42+
This ensures that each MPI rank saves its profiling data to a separate file, preventing conflicts in multi-rank simulations.
3143

3244
This is intended for development and debugging purposes.

source/api_cc/include/common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ void get_env_nthreads(int& num_intra_nthreads, int& num_inter_nthreads);
172172
**/
173173
void get_env_pytorch_profiler(bool& enable_profiler, std::string& output_dir);
174174

175+
/**
176+
* @brief Get MPI rank if MPI is available and initialized, otherwise return -1.
177+
* @return The MPI rank (>= 0), or -1 if MPI is not available/initialized.
178+
**/
179+
int get_mpi_rank();
180+
175181
/**
176182
* @brief Dynamically load OP library. This should be called before loading
177183
* graphs.

source/api_cc/src/DeepPotPT.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,16 @@ DeepPotPT::~DeepPotPT() {
142142
#ifdef BUILD_PYTORCH
143143
if (profiler_enabled) {
144144
try {
145-
// Save profiler results to file
146-
std::string output_file = profiler_output_dir + "/pytorch_profiler_trace.json";
145+
// Save profiler results to file with MPI rank if available
146+
int rank = get_mpi_rank();
147+
std::string output_file;
148+
if (rank >= 0) {
149+
// MPI is available and initialized, include rank in filename
150+
output_file = profiler_output_dir + "/pytorch_profiler_trace_rank" + std::to_string(rank) + ".json";
151+
} else {
152+
// MPI not available or not initialized, use original filename
153+
output_file = profiler_output_dir + "/pytorch_profiler_trace.json";
154+
}
147155
profiler_result = torch::profiler::disableProfiler();
148156
if (profiler_result) {
149157
profiler_result->save(output_file);

source/api_cc/src/DeepSpinPT.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,16 @@ DeepSpinPT::~DeepSpinPT() {
142142
#ifdef BUILD_PYTORCH
143143
if (profiler_enabled) {
144144
try {
145-
// Save profiler results to file
146-
std::string output_file = profiler_output_dir + "/pytorch_profiler_trace.json";
145+
// Save profiler results to file with MPI rank if available
146+
int rank = get_mpi_rank();
147+
std::string output_file;
148+
if (rank >= 0) {
149+
// MPI is available and initialized, include rank in filename
150+
output_file = profiler_output_dir + "/pytorch_profiler_trace_rank" + std::to_string(rank) + ".json";
151+
} else {
152+
// MPI not available or not initialized, use original filename
153+
output_file = profiler_output_dir + "/pytorch_profiler_trace.json";
154+
}
147155
profiler_result = torch::profiler::disableProfiler();
148156
if (profiler_result) {
149157
profiler_result->save(output_file);

source/api_cc/src/common.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
#include <sstream>
99
#include <string>
1010

11+
// Try to include MPI if available - this will be a no-op if MPI is not available
12+
#ifdef __has_include
13+
#if __has_include(<mpi.h>)
14+
#include <mpi.h>
15+
#endif
16+
#endif
17+
1118
#include "AtomMap.h"
1219
#include "device.h"
1320
#if defined(_WIN32)
@@ -398,6 +405,20 @@ void deepmd::get_env_pytorch_profiler(bool& enable_profiler, std::string& output
398405
}
399406
}
400407

408+
/**
 * @brief Report the rank of this process in MPI_COMM_WORLD.
 *
 * @return The rank (>= 0) when <mpi.h> was included at compile time and MPI
 *         is currently usable; -1 when MPI support was not compiled in, MPI
 *         has not been initialized, MPI has already been finalized, or the
 *         rank query itself fails.
 **/
int deepmd::get_mpi_rank() {
  // -1 signals "no usable MPI rank"; the profiler destructors test
  // `rank >= 0` to decide whether to add a rank suffix to the filename.
  const int no_rank = -1;
#ifdef MPI_VERSION
  // MPI_VERSION is a macro that the MPI standard requires every mpi.h to
  // define, so it is a portable "was <mpi.h> included" check. Include-guard
  // names differ between implementations (e.g. MPI_INCLUDED, OMPI_MPI_H).
  // NOTE(review): once this branch is compiled in, the library must also be
  // linked against MPI - confirm the build system does so.
  int initialized = 0;
  if (MPI_Initialized(&initialized) != MPI_SUCCESS || !initialized) {
    return no_rank;
  }
  // This function is called from destructors, which may run after
  // MPI_Finalize; MPI_Comm_rank must not be called in that state.
  int finalized = 0;
  if (MPI_Finalized(&finalized) != MPI_SUCCESS || finalized) {
    return no_rank;
  }
  int rank = no_rank;
  if (MPI_Comm_rank(MPI_COMM_WORLD, &rank) != MPI_SUCCESS) {
    return no_rank;
  }
  return rank;
#else
  return no_rank;
#endif
}
421+
401422
static inline void _load_library_path(std::string dso_path) {
402423
#if defined(_WIN32)
403424
void* dso_handle = LoadLibrary(dso_path.c_str());

source/api_cc/tests/test_pytorch_profiler.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,10 @@ TEST_F(TestPyTorchProfiler, test_profiler_disabled_with_zero) {
7373

7474
EXPECT_FALSE(enable_profiler);
7575
EXPECT_EQ(output_dir, "./profiler_output");
76+
}
77+
78+
TEST_F(TestPyTorchProfiler, test_mpi_rank_detection) {
  // get_mpi_rank() yields -1 when no MPI rank can be determined and a
  // non-negative rank otherwise, so any value below -1 is invalid.
  const int detected_rank = deepmd::get_mpi_rank();
  EXPECT_GE(detected_rank, -1);
}

0 commit comments

Comments
 (0)