@@ -614,40 +614,16 @@ void* cget_managed_ptr(size_t bytes) {
614614 return ptr;
615615}
616616
617- #include < cuda_runtime_api.h>
618- #ifndef CUDART_VERSION
619- #define CUDART_VERSION 0
620- #endif
621-
622- // Unified helper: CUDA13+ uses cudaMemLocation; older CUDA/HIP keeps int device
623- static inline cudaError_t bnb_prefetch_to (void * ptr, size_t bytes, int device, cudaStream_t stream) {
624- #if defined(BUILD_CUDA) && !defined(BUILD_HIP) && (CUDART_VERSION >= 13000)
625- cudaMemLocation loc{};
626- if (device == cudaCpuDeviceId) {
627- loc.type = cudaMemLocationTypeHost;
628- loc.id = 0 ;
629- } else {
630- loc.type = cudaMemLocationTypeDevice;
631- loc.id = device;
632- }
633- return cudaMemPrefetchAsync (ptr, bytes, loc, stream);
634- #else
635- // Older CUDA or HIP path (your BUILD_HIP macro maps cudaMemPrefetchAsync -> hipMemPrefetchAsync)
636- return cudaMemPrefetchAsync (ptr, bytes, device, stream);
637- #endif
638- }
639-
640617void cprefetch (void * ptr, size_t bytes, int device) {
641- // Only check the device attribute when prefetching to a device
642- if (device != cudaCpuDeviceId) {
643- int hasPrefetch = 0 ;
644- CUDA_CHECK_RETURN (cudaDeviceGetAttribute (
645- &hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // ~40ns
646- if (hasPrefetch == 0 )
647- return ;
648- }
649618
650- CUDA_CHECK_RETURN (bnb_prefetch_to (ptr, bytes, device, /* stream=*/ 0 ));
619+ int hasPrefetch = 0 ;
620+ CUDA_CHECK_RETURN (
621+ cudaDeviceGetAttribute (&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)
622+ ); // 40ns overhead
623+ if (hasPrefetch == 0 )
624+ return ;
625+
626+ CUDA_CHECK_RETURN (cudaMemPrefetchAsync (ptr, bytes, device, 0 ));
651627 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
652628}
653629
0 commit comments