@@ -737,19 +737,32 @@ void DMatrixToCSR(DMatrix* dmat, std::vector<float>* p_data, std::vector<size_t>
737737#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
738738
739739using CUDAMemoryResource = rmm::mr::cuda_memory_resource;
740+ #if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
741+ using PoolMemoryResource = rmm::mr::pool_memory_resource;
742+ #else
740743using PoolMemoryResource = rmm::mr::pool_memory_resource<CUDAMemoryResource>;
744+ #endif
741745class RMMAllocator {
742746 public:
747+ #if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
748+ std::vector<PoolMemoryResource> pool_mr;
749+ #else
743750 std::vector<std::unique_ptr<CUDAMemoryResource>> cuda_mr;
744751 std::vector<std::unique_ptr<PoolMemoryResource>> pool_mr;
752+ #endif
745753 int n_gpu;
746754 RMMAllocator () : n_gpu(curt::AllVisibleGPUs()) {
747755 int current_device;
748756 CHECK_EQ (cudaGetDevice (¤t_device), cudaSuccess);
749757 for (int i = 0 ; i < n_gpu; ++i) {
750758 CHECK_EQ (cudaSetDevice (i), cudaSuccess);
759+ #if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
760+ CUDAMemoryResource cuda_mr;
761+ pool_mr.push_back (PoolMemoryResource{cuda_mr, 0ul });
762+ #else
751763 cuda_mr.push_back (std::make_unique<CUDAMemoryResource>());
752764 pool_mr.push_back (std::make_unique<PoolMemoryResource>(cuda_mr[i].get (), 0ul ));
765+ #endif
753766 }
754767 CHECK_EQ (cudaSetDevice (current_device), cudaSuccess);
755768 }
@@ -771,7 +784,11 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
771784 LOG (INFO) << " Using RMM memory pool" ;
772785 auto ptr = RMMAllocatorPtr (new RMMAllocator (), DeleteRMMResource);
773786 for (int i = 0 ; i < ptr->n_gpu ; ++i) {
787+ #if RMM_VERSION_MAJOR > 26 || (RMM_VERSION_MAJOR == 26 && RMM_VERSION_MINOR >= 6)
788+ rmm::mr::set_per_device_resource_ref (rmm::cuda_device_id (i), ptr->pool_mr [i]);
789+ #else
774790 rmm::mr::set_per_device_resource (rmm::cuda_device_id (i), ptr->pool_mr [i].get ());
791+ #endif
775792 }
776793 GlobalConfigThreadLocalStore::Get ()->UpdateAllowUnknown (Args{{" use_rmm" , " true" }});
777794 return ptr;
0 commit comments