Skip to content

Commit 23484e4

Browse files
authored
Merge pull request QMCPACK#5507 from ye-luo/refactor-multispline
Refactor MultiSpline and MultiSplineOffload
2 parents b0a4d1c + aa8c507 commit 23484e4

13 files changed

Lines changed: 136 additions & 56 deletions

File tree

src/Estimators/tests/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ if(USE_OBJECT_TARGET)
8383
test_estimators_help
8484
qmcham
8585
qmcwfs
86-
qmcparticle
8786
qmcwfs_omptarget
87+
spline2
88+
spline2_omptarget
89+
qmcparticle
8890
qmcparticle_omptarget
8991
qmcutil
9092
platform_omptarget_LA

src/QMCApp/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ if(USE_OBJECT_TARGET)
5151
qmcestimators
5252
qmcham
5353
qmcwfs
54-
qmcparticle
5554
qmcwfs_omptarget
55+
spline2
56+
spline2_omptarget
57+
qmcparticle
5658
qmcparticle_omptarget
5759
qmcutil
5860
platform_omptarget_LA)

src/QMCDrivers/tests/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ if(USE_OBJECT_TARGET)
5050
qmcestimators
5151
qmcham
5252
qmcwfs
53-
qmcparticle
5453
qmcwfs_omptarget
54+
spline2
55+
spline2_omptarget
56+
qmcparticle
5557
qmcparticle_omptarget
5658
qmcutil
5759
platform_omptarget_LA)
@@ -92,8 +94,10 @@ if(USE_OBJECT_TARGET)
9294
qmcestimators
9395
qmcham
9496
qmcwfs
95-
qmcparticle
9697
qmcwfs_omptarget
98+
spline2
99+
spline2_omptarget
100+
qmcparticle
97101
qmcparticle_omptarget
98102
qmcutil
99103
platform_omptarget_LA)

src/QMCHamiltonians/tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ foreach(CATEGORY coulomb force ham ewald2d)
7070

7171
target_link_libraries(${UTEST_EXE} catch_main qmcham utilities_for_test)
7272
if(USE_OBJECT_TARGET)
73-
target_link_libraries(${UTEST_EXE} qmcwfs qmcparticle qmcwfs_omptarget qmcparticle_omptarget qmcutil platform_omptarget_LA)
73+
target_link_libraries(${UTEST_EXE} qmcwfs qmcwfs_omptarget qmcparticle qmcparticle_omptarget spline2 spline2_omptarget qmcutil platform_omptarget_LA)
7474
endif()
7575

7676
add_unit_test(${UTEST_NAME} 1 1 $<TARGET_FILE:${UTEST_EXE}>)

src/QMCWaveFunctions/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,13 @@ else(USE_OBJECT_TARGET)
149149
endif(USE_OBJECT_TARGET)
150150

151151
target_include_directories(qmcwfs PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
152-
target_link_libraries(qmcwfs PUBLIC qmcutil qmcparticle platform_runtime platform_LA)
152+
target_link_libraries(qmcwfs PUBLIC spline2 qmcutil qmcparticle platform_runtime platform_LA)
153153
target_link_libraries(qmcwfs PRIVATE einspline Math::FFTW3)
154154

155155
add_library(qmcwfs_omptarget OBJECT ${JASTROW_OMPTARGET_SRCS} ${FERMION_OMPTARGET_SRCS})
156156

157157
target_include_directories(qmcwfs_omptarget PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
158-
target_link_libraries(qmcwfs_omptarget PUBLIC qmcutil qmcparticle containers platform_LA)
158+
target_link_libraries(qmcwfs_omptarget PUBLIC spline2 qmcutil qmcparticle containers platform_LA)
159159

160160
target_link_libraries(qmcwfs PUBLIC qmcwfs_omptarget)
161161

src/QMCWaveFunctions/tests/CMakeLists.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,13 @@ foreach(CATEGORY common trialwf sposet jastrow determinant)
181181
${UTEST_EXE}
182182
catch_main
183183
qmcwfs
184-
qmcwfs_omptarget
185-
qmcparticle_omptarget
186184
platform_LA
187185
platform_runtime
188186
sposets_for_testing
189187
utilities_for_test
190188
container_testing)
191189
if(USE_OBJECT_TARGET)
192-
target_link_libraries(${UTEST_EXE} qmcparticle qmcparticle_omptarget qmcwfs_omptarget qmcutil platform_omptarget_LA)
190+
target_link_libraries(${UTEST_EXE} qmcparticle qmcparticle_omptarget qmcwfs_omptarget spline2 spline2_omptarget qmcutil platform_omptarget_LA)
193191
endif()
194192

195193
add_unit_test(${UTEST_NAME} 1 1 $<TARGET_FILE:${UTEST_EXE}>)

src/Sandbox/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ target_link_libraries(sandbox_helper PUBLIC qmcparticle)
99

1010
foreach(p ${ESTEST})
1111
add_executable(${p} ${p}.cpp)
12-
target_link_libraries(${p} einspline qmcparticle sandbox_helper)
12+
target_link_libraries(${p} spline2 qmcparticle sandbox_helper)
1313
add_unit_test(sandbox_${p} 1 3 $<TARGET_FILE:${p}>)
1414
endforeach(p ${ESTEST})

src/spline2/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
#// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
1010
#//////////////////////////////////////////////////////////////////////////////////////
1111

12+
13+
add_library(spline2_omptarget OBJECT MultiBsplineOffload.cpp)
14+
target_link_libraries(spline2_omptarget PUBLIC einspline platform_runtime)
15+
16+
if(USE_OBJECT_TARGET)
17+
add_library(spline2 OBJECT MultiBspline.cpp)
18+
else()
19+
add_library(spline2 MultiBspline.cpp)
20+
endif()
21+
target_link_libraries(spline2 PUBLIC spline2_omptarget)
22+
1223
if(BUILD_UNIT_TESTS)
1324
add_subdirectory(tests)
1425
endif()

src/spline2/MultiBspline.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//////////////////////////////////////////////////////////////////////////////////////
2+
// This file is distributed under the University of Illinois/NCSA Open Source License.
3+
// See LICENSE file in top directory for details.
4+
//
5+
// Copyright (c) 2025 QMCPACK developers.
6+
//
7+
// File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8+
//
9+
// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10+
//////////////////////////////////////////////////////////////////////////////////////
11+
12+
13+
#include "MultiBspline.hpp"
14+
15+
namespace qmcplusplus
16+
{
17+
template<typename T>
18+
typename MultiBsplineBase<T>::SplineType* MultiBspline<T>::createImpl(const Ugrid grid[3],
19+
const typename Base::BoundaryCondition bc[3],
20+
int num_splines)
21+
{
22+
static_assert(std::is_same<T, typename Alloc::value_type>::value, "MultiBspline and Alloc data types must agree!");
23+
if (getAlignedSize<T, Alloc::alignment>(num_splines) != num_splines)
24+
throw std::runtime_error("When creating the data space of MultiBspline, num_splines must be padded!\n");
25+
return myAllocator.allocateMultiBspline(grid[0], grid[1], grid[2], bc[0], bc[1], bc[2], num_splines);
26+
}
27+
28+
template<typename T>
29+
MultiBspline<T>::MultiBspline() = default;
30+
31+
template<typename T>
32+
MultiBspline<T>::~MultiBspline()
33+
{
34+
if (Base::spline_m != nullptr)
35+
myAllocator.destroy(Base::spline_m);
36+
}
37+
38+
template class MultiBspline<float>;
39+
template class MultiBspline<double>;
40+
} // namespace qmcplusplus

src/spline2/MultiBspline.hpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,15 @@ class MultiBspline : public MultiBsplineBase<T>
3939

4040
typename Base::SplineType* createImpl(const Ugrid grid[3],
4141
const typename Base::BoundaryCondition bc[3],
42-
int num_splines) override
43-
{
44-
static_assert(std::is_same<T, typename Alloc::value_type>::value, "MultiBspline and Alloc data types must agree!");
45-
if (getAlignedSize<T, Alloc::alignment>(num_splines) != num_splines)
46-
throw std::runtime_error("When creating the data space of MultiBspline, num_splines must be padded!\n");
47-
return myAllocator.allocateMultiBspline(grid[0], grid[1], grid[2], bc[0], bc[1], bc[2], num_splines);
48-
}
42+
int num_splines) override;
4943

5044
public:
51-
MultiBspline() = default;
52-
53-
~MultiBspline() override
54-
{
55-
if (Base::spline_m != nullptr)
56-
myAllocator.destroy(Base::spline_m);
57-
}
58-
59-
void finalize() override {}
45+
MultiBspline();
46+
~MultiBspline() override;
6047
};
6148

49+
extern template class MultiBspline<float>;
50+
extern template class MultiBspline<double>;
6251
} // namespace qmcplusplus
6352

6453
#endif

0 commit comments

Comments
 (0)