Skip to content

Commit 9f1bcc9

Browse files
committed
fix(kvarn): cast HIP kernel pointer for dynamic shmem attribute
1 parent 7eff066 commit 9f1bcc9

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

ggml/src/ggml-cuda/kvarn.cu

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -638,8 +638,16 @@ void ggml_cuda_op_kvarn_store(ggml_backend_cuda_context & ctx, ggml_tensor * dst
638638
const size_t smpbo = ggml_cuda_info().devices[ctx.device].smpbo;
639639

640640
if (!force_low_shmem && smpbo >= KVAR_N_SHARED_BYTES) {
641-
#if !defined(GGML_USE_MUSA)
642-
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<const void*>(kvarn_store_kernel_hishmem), cudaFuncAttributeMaxDynamicSharedMemorySize, KVAR_N_SHARED_BYTES));
641+
#if defined(GGML_USE_HIP)
642+
CUDA_CHECK(hipFuncSetAttribute(
643+
reinterpret_cast<const void *>(&kvarn_store_kernel_hishmem),
644+
hipFuncAttributeMaxDynamicSharedMemorySize,
645+
KVAR_N_SHARED_BYTES));
646+
#elif !defined(GGML_USE_MUSA)
647+
CUDA_CHECK(cudaFuncSetAttribute(
648+
kvarn_store_kernel_hishmem,
649+
cudaFuncAttributeMaxDynamicSharedMemorySize,
650+
KVAR_N_SHARED_BYTES));
643651
#endif
644652
kvarn_store_kernel_hishmem<<<current->ne[1], KVAR_N_DIM, KVAR_N_SHARED_BYTES, ctx.stream()>>>(
645653
(const float *) current->data,

0 commit comments

Comments
 (0)