Skip to content

Commit e0d3dfd

Browse files
authored
Update test helpers for RMM >= 26.06 API changes (#12141)
1 parent e66c5f3 commit e0d3dfd

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

tests/cpp/helpers.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,19 +737,32 @@ void DMatrixToCSR(DMatrix* dmat, std::vector<float>* p_data, std::vector<size_t>
737737
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
738738

739739
using CUDAMemoryResource = rmm::mr::cuda_memory_resource;
740+
#if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
741+
using PoolMemoryResource = rmm::mr::pool_memory_resource;
742+
#else
740743
using PoolMemoryResource = rmm::mr::pool_memory_resource<CUDAMemoryResource>;
744+
#endif
741745
class RMMAllocator {
742746
public:
747+
#if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
748+
std::vector<PoolMemoryResource> pool_mr;
749+
#else
743750
std::vector<std::unique_ptr<CUDAMemoryResource>> cuda_mr;
744751
std::vector<std::unique_ptr<PoolMemoryResource>> pool_mr;
752+
#endif
745753
int n_gpu;
746754
RMMAllocator() : n_gpu(curt::AllVisibleGPUs()) {
747755
int current_device;
748756
CHECK_EQ(cudaGetDevice(&current_device), cudaSuccess);
749757
for (int i = 0; i < n_gpu; ++i) {
750758
CHECK_EQ(cudaSetDevice(i), cudaSuccess);
759+
#if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
760+
CUDAMemoryResource cuda_mr;
761+
pool_mr.push_back(PoolMemoryResource{cuda_mr, 0ul});
762+
#else
751763
cuda_mr.push_back(std::make_unique<CUDAMemoryResource>());
752764
pool_mr.push_back(std::make_unique<PoolMemoryResource>(cuda_mr[i].get(), 0ul));
765+
#endif
753766
}
754767
CHECK_EQ(cudaSetDevice(current_device), cudaSuccess);
755768
}
@@ -771,7 +784,11 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
771784
LOG(INFO) << "Using RMM memory pool";
772785
auto ptr = RMMAllocatorPtr(new RMMAllocator(), DeleteRMMResource);
773786
for (int i = 0; i < ptr->n_gpu; ++i) {
787+
#if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
788+
rmm::mr::set_per_device_resource_ref(rmm::cuda_device_id(i), ptr->pool_mr[i]);
789+
#else
774790
rmm::mr::set_per_device_resource(rmm::cuda_device_id(i), ptr->pool_mr[i].get());
791+
#endif
775792
}
776793
GlobalConfigThreadLocalStore::Get()->UpdateAllowUnknown(Args{{"use_rmm", "true"}});
777794
return ptr;

0 commit comments

Comments
 (0)