Skip to content

Commit 402d94d

Browse files
committed
Applied cudaMemPrefetchAsync API change for CUDA 13
1 parent 3f582e7 commit 402d94d

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

pfsimulator/amps/oas3/amps_allreduce.c

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,28 @@ int amps_AllReduce(amps_Comm comm, amps_Invoice invoice, MPI_Op operation)
150150

151151
if (cudaGetLastError() == cudaSuccess && attributes.type > 1)
152152
{
153+
#if CUDART_VERSION >= 13000
154+
int deviceIndex;
155+
CUDA_ERRCHK(cudaGetDevice(&deviceIndex));
156+
struct cudaMemLocation location = {};
157+
location.type = cudaMemLocationTypeHost;
158+
location.id = deviceIndex;
159+
#endif
153160
if (stride == 1)
161+
#if CUDART_VERSION >= 13000
162+
CUDA_ERRCHK(cudaMemPrefetchAsync(data, (size_t)len * element_size, location, 0, 0));
163+
#else
154164
CUDA_ERRCHK(cudaMemPrefetchAsync(data, (size_t)len * element_size, cudaCpuDeviceId, 0));
165+
#endif
155166
else
156167
for (ptr_src = data;
157168
ptr_src < data + len * stride * element_size;
158169
ptr_src += stride * element_size)
170+
#if CUDART_VERSION >= 13000
171+
/* Strided layout: prefetch only the element at ptr_src, mirroring the pre-CUDA-13 branch below. */
CUDA_ERRCHK(cudaMemPrefetchAsync(ptr_src, (size_t)element_size, location, 0, 0));
172+
#else
159173
CUDA_ERRCHK(cudaMemPrefetchAsync(ptr_src, (size_t)element_size, cudaCpuDeviceId, 0));
160-
174+
#endif
161175
CUDA_ERRCHK(cudaStreamSynchronize(0));
162176
}
163177
#endif

0 commit comments

Comments
 (0)