Skip to content

Commit bef474a

Browse files
TimDettmersclaude
andcommitted
Fix HIP build errors in unified kernel files
- Add hip/hip_bfloat16.h include to compat.cuh (bnb_bfloat16 type alias requires hip_bfloat16 to be defined) - Add __syncwarp() no-op macro for HIP (AMD warps are always in lockstep) - Add hipblas version check (#if hipblasVersionMajor >= 3) for GemmEx calls (ROCm 6.1 ships hipblas v2 which uses HIPBLAS_R_* not HIPBLAS_COMPUTE_*) - Fix include in ops.cuh: common.h -> common.cuh (BNB_WARP_SIZE visibility) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d7f3e15 commit bef474a

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

csrc/compat.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#if BNB_HIP
2727

2828
#include <hip/hip_fp16.h>
29+
#include <hip/hip_bfloat16.h>
2930
#include <hip/hip_math_constants.h>
3031
#include <hip/hip_runtime.h>
3132
#include <hipblas/hipblas.h>
@@ -85,6 +86,17 @@ using bnb_error_t = cudaError_t;
8586
// Keep backward compat for existing code during migration
8687
#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value)
8788

89+
// ============================================================================
90+
// Warp synchronization
91+
//
92+
// HIP warps are always in lockstep (no independent thread scheduling),
93+
// so __syncwarp() is a no-op. CUDA needs it for warp convergence.
94+
// ============================================================================
95+
96+
#if BNB_HIP
97+
#define __syncwarp() do {} while(0)
98+
#endif
99+
88100
// ============================================================================
89101
// BFloat16 type alias
90102
//

csrc/ops.cu

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,18 @@ void gemmex(
270270
#if BNB_HIP
271271
hipblasStatus_t status;
272272

273+
#if hipblasVersionMajor >= 3
273274
status = hipblasGemmEx(
274275
context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k,
275276
alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, C, HIP_R_32I, ldc, HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT
276277
);
278+
#else
279+
status = hipblasGemmEx(
280+
context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k,
281+
alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, C, HIPBLAS_R_32I, ldc, HIPBLAS_R_32I,
282+
HIPBLAS_GEMM_DEFAULT
283+
);
284+
#endif
277285

278286
if (status != HIPBLAS_STATUS_SUCCESS) {
279287
std::cout << "HIPBLAS ERROR: Status " << status << std::endl;
@@ -304,11 +312,19 @@ void strided_gemmex(
304312
#if BNB_HIP
305313
hipblasStatus_t status;
306314

315+
#if hipblasVersionMajor >= 3
307316
status = hipblasGemmStridedBatchedEx(
308317
context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k,
309318
alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, C, HIP_R_32I,
310319
ldc, (long long int)strideC, batchCount, HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT
311320
);
321+
#else
322+
status = hipblasGemmStridedBatchedEx(
323+
context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k,
324+
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, C,
325+
HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT
326+
);
327+
#endif
312328

313329
if (status != HIPBLAS_STATUS_SUCCESS) {
314330
std::cout << "HIPBLAS ERROR: Status " << status << std::endl;

csrc/ops.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <vector>
1515

1616
#include "compat.cuh"
17-
#include <common.h>
17+
#include "common.cuh"
1818

1919
// ============================================================================
2020
// Error checking helpers

0 commit comments

Comments
 (0)