Skip to content

Commit 7f3f843

Browse files
authored
Fix for issue ggml-org#22974. Cast both operands to float before adding, then cast the sum to the destination type. Avoids half+half operator ambiguity. (ggml-org#22994)
1 parent ec562eb commit 7f3f843

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

ggml/src/ggml-cuda/allreduce.cu

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -184,13 +184,15 @@ static __global__ void ggml_cuda_ar_kernel(
184184
#pragma unroll
185185
for (int k = 0; k < ELEMS_PER_VEC; ++k) {
186186
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[off + k]);
187-
recvbuf[off + k] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(wire[k]);
187+
recvbuf[off + k] = ggml_cuda_cast<T_dst>(
188+
ggml_cuda_cast<float>(d_low) + ggml_cuda_cast<float>(wire[k]));
188189
}
189190
}
190191
if (bid == 0 && tid < count - tail) {
191192
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]);
192-
recvbuf[tail + tid] =
193-
ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(host_other[tail + tid]);
193+
recvbuf[tail + tid] = ggml_cuda_cast<T_dst>(
194+
ggml_cuda_cast<float>(d_low) +
195+
ggml_cuda_cast<float>(host_other[tail + tid]));
194196
}
195197
}
196198
}
@@ -210,7 +212,8 @@ static __global__ void ggml_cuda_ar_add_kernel(
210212
const int nt = gridDim.x * blockDim.x;
211213
for (int i = tid; i < count; i += nt) {
212214
const T_src d_low = ggml_cuda_cast<T_src>(dst[i]);
213-
dst[i] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(src[i]);
215+
dst[i] = ggml_cuda_cast<T_dst>(
216+
ggml_cuda_cast<float>(d_low) + ggml_cuda_cast<float>(src[i]));
214217
}
215218
}
216219

0 commit comments

Comments
 (0)