Skip to content

Commit 06b6a53

Browse files
committed
ggml-cuda : add ar_add() to avoid ambiguous operator+ for half/bfloat16 in CUDA 11.8
1 parent a2839b4 commit 06b6a53

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

ggml/src/ggml-cuda/allreduce.cu

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8;
105105
// blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are
106106
// handled only by block 0 to avoid cross-block writes to the same slots.
107107
// ---------------------------------------------------------------------------
108+
109+
// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_53 (__CUDA_ARCH__ < 530),
110+
// so use the explicit intrinsics to avoid ambiguous implicit conversions.
111+
// Element-wise addition helper used by the all-reduce kernels.
//
// For the 16-bit floating-point types (half, nv_bfloat16) the sum is computed
// with the explicit __hadd intrinsic rather than the built-in `+`, so the call
// never depends on an operator+ overload being available for those types.
// Every other T falls back to the ordinary `a + b`.
//
// Both operands and the result share the same type T.
template<typename T>
static __device__ inline T ar_add(T a, T b) {
    // Resolved at compile time: only one of the two return paths is
    // instantiated for a given T.
    constexpr bool use_hadd_intrinsic =
        std::is_same_v<T, half> || std::is_same_v<T, nv_bfloat16>;

    if constexpr (use_hadd_intrinsic) {
        // __hadd is overloaded for both half and nv_bfloat16.
        return __hadd(a, b);
    } else {
        return a + b;
    }
}
121+
108122
template <typename T_dst, typename T_wire>
109123
static __global__ void ggml_cuda_ar_kernel(
110124
const T_dst * sendbuf,
@@ -184,13 +198,13 @@ static __global__ void ggml_cuda_ar_kernel(
184198
#pragma unroll
185199
for (int k = 0; k < ELEMS_PER_VEC; ++k) {
186200
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[off + k]);
187-
recvbuf[off + k] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(wire[k]);
201+
recvbuf[off + k] = ar_add(ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(wire[k]));
188202
}
189203
}
190204
if (bid == 0 && tid < count - tail) {
191205
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]);
192206
recvbuf[tail + tid] =
193-
ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(host_other[tail + tid]);
207+
ar_add(ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(host_other[tail + tid]));
194208
}
195209
}
196210
}
@@ -210,7 +224,7 @@ static __global__ void ggml_cuda_ar_add_kernel(
210224
const int nt = gridDim.x * blockDim.x;
211225
for (int i = tid; i < count; i += nt) {
212226
const T_src d_low = ggml_cuda_cast<T_src>(dst[i]);
213-
dst[i] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(src[i]);
227+
dst[i] = ar_add(ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(src[i]));
214228
}
215229
}
216230

0 commit comments

Comments
 (0)