|
| 1 | +// HIP compatibility shim: maps <cuda_runtime.h> to HIP equivalents. |
| 2 | +// Included transparently when building with -I hip_compat on ROCm. |
| 3 | +#pragma once |
| 4 | + |
| 5 | +// hip/hip_runtime.h requires exactly one of __HIP_PLATFORM_AMD__ or |
| 6 | +// __HIP_PLATFORM_NVIDIA__ to be defined. hipcc sets it automatically; |
| 7 | +// g++ (used for plain CXX sources in the dflash build) does not. |
| 8 | +#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) |
| 9 | +# define __HIP_PLATFORM_AMD__ |
| 10 | +#endif |
| 11 | + |
| 12 | +#include <hip/hip_runtime.h> |
| 13 | +#include <hip/hip_runtime_api.h> |
| 14 | + |
| 15 | +// Type aliases |
| 16 | +using cudaStream_t = hipStream_t; |
| 17 | +using cudaEvent_t = hipEvent_t; |
| 18 | +using cudaError_t = hipError_t; |
| 19 | +using cudaMemcpyKind = hipMemcpyKind; |
| 20 | +using cudaDeviceProp = hipDeviceProp_t; |
| 21 | + |
| 22 | +// Memcpy kind constants |
| 23 | +#define cudaMemcpyHostToHost hipMemcpyHostToHost |
| 24 | +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice |
| 25 | +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost |
| 26 | +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice |
| 27 | +#define cudaMemcpyDefault hipMemcpyDefault |
| 28 | + |
| 29 | +// Error codes |
| 30 | +#define cudaSuccess hipSuccess |
| 31 | +#define cudaErrorInvalidValue hipErrorInvalidValue |
| 32 | + |
| 33 | +// Memory functions |
| 34 | +#define cudaMalloc hipMalloc |
| 35 | +#define cudaMallocHost hipHostMalloc |
| 36 | +#define cudaFree hipFree |
| 37 | +#define cudaFreeHost hipHostFree |
| 38 | +#define cudaMemcpy hipMemcpy |
| 39 | +#define cudaMemcpyAsync hipMemcpyAsync |
| 40 | +#define cudaMemcpy2DAsync hipMemcpy2DAsync |
| 41 | +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync |
| 42 | +#define cudaMemset hipMemset |
| 43 | +#define cudaMemsetAsync hipMemsetAsync |
| 44 | + |
| 45 | +// Stream functions |
| 46 | +#define cudaStreamCreate hipStreamCreate |
| 47 | +#define cudaStreamDestroy hipStreamDestroy |
| 48 | +#define cudaStreamSynchronize hipStreamSynchronize |
| 49 | +#define cudaStreamDefault hipStreamDefault |
| 50 | +#define cudaStreamNonBlocking hipStreamNonBlocking |
| 51 | + |
| 52 | +// Device functions |
| 53 | +#define cudaGetDevice hipGetDevice |
| 54 | +#define cudaSetDevice hipSetDevice |
| 55 | +#define cudaDeviceSynchronize hipDeviceSynchronize |
| 56 | +#define cudaGetDeviceProperties hipGetDeviceProperties |
| 57 | +#define cudaDeviceReset hipDeviceReset |
| 58 | + |
| 59 | +// Event functions |
| 60 | +#define cudaEventCreate hipEventCreate |
| 61 | +#define cudaEventDestroy hipEventDestroy |
| 62 | +#define cudaEventRecord hipEventRecord |
| 63 | +#define cudaEventSynchronize hipEventSynchronize |
| 64 | +#define cudaEventElapsedTime hipEventElapsedTime |
| 65 | +#define cudaEventCreateWithFlags hipEventCreateWithFlags |
| 66 | +#define cudaEventDisableTiming hipEventDisableTiming |
| 67 | + |
| 68 | +// Kernel attribute |
| 69 | +#define cudaFuncSetAttribute hipFuncSetAttribute |
| 70 | +#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize |
| 71 | + |
| 72 | +// Error checking |
| 73 | +#define cudaGetLastError hipGetLastError |
| 74 | +#define cudaGetErrorString hipGetErrorString |
| 75 | + |
| 76 | +// Launch bounds |
| 77 | +#define __launch_bounds__ __launch_bounds__ |
| 78 | + |
| 79 | +// Stream capture status (added CUDA 10.0 — ROCm compat headers may omit this) |
| 80 | +#define cudaStreamCaptureStatus hipStreamCaptureStatus |
| 81 | +#define cudaStreamCaptureStatusNone hipStreamCaptureStatusNone |
| 82 | +#define cudaStreamCaptureStatusActive hipStreamCaptureStatusActive |
| 83 | +#define cudaStreamCaptureStatusInvalidated hipStreamCaptureStatusInvalidated |
| 84 | +#define cudaStreamIsCapturing hipStreamIsCapturing |
| 85 | + |
| 86 | +// Peer device access |
| 87 | +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer |
| 88 | +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess |
| 89 | +#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled |
| 90 | + |
| 91 | +// Device count |
| 92 | +#define cudaGetDeviceCount hipGetDeviceCount |
0 commit comments