Skip to content

Commit aa8b3a8

Browse files
committed
fix cmake
1 parent ba8f52e commit aa8b3a8

2 files changed

Lines changed: 24 additions & 18 deletions

File tree

source/api_cc/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ file(GLOB INC_SRC include/*.h ${CMAKE_CURRENT_BINARY_DIR}/version.h)
88
set(libname "${LIB_DEEPMD_CC}")
99
add_library(${libname} SHARED ${LIB_SRC})
1010

11+
if(ENABLE_PYTORCH)
12+
find_package(MPI)
13+
if(MPI_FOUND)
14+
target_link_libraries(${libname} PRIVATE MPI::MPI_CXX)
15+
target_compile_definitions(${libname} PRIVATE USE_MPI)
16+
endif()
17+
endif()
18+
1119
# link: libdeepmd libdeepmd_op libtensorflow_cc libtensorflow_framework
1220
target_link_libraries(${libname} PUBLIC ${LIB_DEEPMD})
1321
if(ENABLE_TENSORFLOW)

source/api_cc/src/DeepPotPT.cc

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212

1313
#ifdef USE_MPI
1414
#include <mpi.h>
15-
#ifdef OMPI_MPI_H
16-
#include <mpi-ext.h>
17-
#endif
1815
#endif
1916

2017
using namespace deepmd;
@@ -202,23 +199,24 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
202199
}
203200

204201
std::vector<int> recvnum_new(nswap, 0);
205-
#ifdef MPI_FOUND
206-
if (lmp_list.world) {
207-
MPI_Comm comm = *static_cast<MPI_Comm*>(lmp_list.world);
208-
const int TAG_BASE = 0x7a31;
209-
for (int s = 0; s < nswap; ++s) {
210-
const int send_to = lmp_list.sendproc[s];
211-
const int recv_from = lmp_list.recvproc[s];
212-
int send_cnt = sendnum_new[s];
213-
int recv_cnt = 0;
214-
MPI_Sendrecv(&send_cnt, 1, MPI_INT, send_to, TAG_BASE + s, &recv_cnt,
215-
1, MPI_INT, recv_from, TAG_BASE + s, comm,
216-
MPI_STATUS_IGNORE);
217-
recvnum_new[s] = recv_cnt;
218-
}
219-
} else
202+
#ifdef USE_MPI
203+
if (lmp_list.world) {
204+
MPI_Comm comm = *static_cast<MPI_Comm*>(lmp_list.world);
205+
const int TAG_BASE = 0x7a31;
206+
for (int s = 0; s < nswap; ++s) {
207+
const int send_to = lmp_list.sendproc[s];
208+
const int recv_from = lmp_list.recvproc[s];
209+
int send_cnt = sendnum_new[s];
210+
int recv_cnt = 0;
211+
MPI_Sendrecv(&send_cnt, 1, MPI_INT, send_to, TAG_BASE + s,
212+
&recv_cnt, 1, MPI_INT, recv_from, TAG_BASE + s,
213+
comm, MPI_STATUS_IGNORE);
214+
recvnum_new[s] = recv_cnt;
215+
}
216+
} else
220217
#endif
221218
{
219+
// need check
222220
for (int s = 0; s < nswap; ++s) {
223221
recvnum_new[s] = sendnum_new[s];
224222
}

0 commit comments

Comments
 (0)