Skip to content

Commit 1c7f0e8

Browse files
committed
Replace cub::Max() with cuda::maximum<> in kernel reductions
1 parent 6dc9b51 commit 1c7f0e8

1 file changed

Lines changed: 8 additions & 32 deletions

File tree

csrc/pythonInterface.cpp

Lines changed: 8 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -614,40 +614,16 @@ void* cget_managed_ptr(size_t bytes) {
614614
return ptr;
615615
}
616616

617-
#include <cuda_runtime_api.h>
618-
#ifndef CUDART_VERSION
619-
#define CUDART_VERSION 0
620-
#endif
621-
622-
// Unified helper: CUDA13+ uses cudaMemLocation; older CUDA/HIP keeps int device
623-
static inline cudaError_t bnb_prefetch_to(void* ptr, size_t bytes, int device, cudaStream_t stream) {
624-
#if defined(BUILD_CUDA) && !defined(BUILD_HIP) && (CUDART_VERSION >= 13000)
625-
cudaMemLocation loc{};
626-
if (device == cudaCpuDeviceId) {
627-
loc.type = cudaMemLocationTypeHost;
628-
loc.id = 0;
629-
} else {
630-
loc.type = cudaMemLocationTypeDevice;
631-
loc.id = device;
632-
}
633-
return cudaMemPrefetchAsync(ptr, bytes, loc, stream);
634-
#else
635-
// Older CUDA or HIP path (your BUILD_HIP macro maps cudaMemPrefetchAsync -> hipMemPrefetchAsync)
636-
return cudaMemPrefetchAsync(ptr, bytes, device, stream);
637-
#endif
638-
}
639-
640617
void cprefetch(void* ptr, size_t bytes, int device) {
641-
// Only check the device attribute when prefetching to a device
642-
if (device != cudaCpuDeviceId) {
643-
int hasPrefetch = 0;
644-
CUDA_CHECK_RETURN(cudaDeviceGetAttribute(
645-
&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // ~40ns
646-
if (hasPrefetch == 0)
647-
return;
648-
}
649618

650-
CUDA_CHECK_RETURN(bnb_prefetch_to(ptr, bytes, device, /*stream=*/0));
619+
int hasPrefetch = 0;
620+
CUDA_CHECK_RETURN(
621+
cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)
622+
); // 40ns overhead
623+
if (hasPrefetch == 0)
624+
return;
625+
626+
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
651627
CUDA_CHECK_RETURN(cudaPeekAtLastError());
652628
}
653629

0 commit comments

Comments
 (0)