@@ -105,6 +105,20 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8;
105105// blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are
106106// handled only by block 0 to avoid cross-block writes to the same slots.
107107// ---------------------------------------------------------------------------
108+
// Adds two values of the same type; used by the all-reduce kernels for their
// element-wise sums.
//
// half / nv_bfloat16 operands are widened to float, added there, and rounded
// back with the conversion intrinsics. CUDA 11.x headers gate the native
// narrow arithmetic -- operator+ and __hadd alike -- behind sm_53 for half and
// sm_80 for bfloat16; below those targets a plain `a + b` is ambiguous because
// of the types' implicit conversions, and __hadd is not declared either. The
// conversion intrinsics carry no such architecture requirement.
//
// The float sum keeps 24 significand bits, which is at least 2p+2 for both
// narrow formats (p = 11 for half, p = 8 for bfloat16), so rounding it back
// yields the same value as a single correctly rounded narrow add.
//
// Any other T is added with its own operator+.
template <typename T>
static __device__ inline T ar_add(T a, T b) {
    if constexpr (std::is_same_v<T, half>) {
        return __float2half(__half2float(a) + __half2float(b));
    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
        return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
    } else {
        return a + b;
    }
}
121+
108122template <typename T_dst, typename T_wire>
109123static __global__ void ggml_cuda_ar_kernel (
110124 const T_dst * sendbuf,
@@ -184,13 +198,13 @@ static __global__ void ggml_cuda_ar_kernel(
184198 #pragma unroll
185199 for (int k = 0 ; k < ELEMS_PER_VEC ; ++k) {
186200 const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[off + k]);
187- recvbuf[off + k] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(wire[k]);
201+ recvbuf[off + k] = ar_add ( ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(wire[k]) );
188202 }
189203 }
190204 if (bid == 0 && tid < count - tail) {
191205 const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]);
192206 recvbuf[tail + tid] =
193- ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(host_other[tail + tid]);
207+ ar_add ( ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(host_other[tail + tid]) );
194208 }
195209 }
196210}
@@ -210,7 +224,7 @@ static __global__ void ggml_cuda_ar_add_kernel(
210224 const int nt = gridDim .x * blockDim .x ;
211225 for (int i = tid; i < count; i += nt) {
212226 const T_src d_low = ggml_cuda_cast<T_src>(dst[i]);
213- dst[i] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(src[i]);
227+ dst[i] = ar_add ( ggml_cuda_cast<T_dst>(d_low), ggml_cuda_cast<T_dst>(src[i]) );
214228 }
215229}
216230
0 commit comments