Skip to content

Commit f8977c7

Browse files
committed
fix(hip): resolve undefined references for half quants vec dispatch
1 parent df8933c commit f8977c7

1 file changed

Lines changed: 39 additions & 4 deletions

File tree

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ list(APPEND GGML_SOURCES_ROCM ${SRCS})
9999
q6_0
100100
q5_1
101101
q5_0
102+
q4_1
102103
turbo4_tcq
103104
turbo4_0
104-
q4_1
105105
q4_0
106106
q3_1
107107
turbo3_tcq
@@ -113,6 +113,38 @@ list(APPEND GGML_SOURCES_ROCM ${SRCS})
113113
q2_0
114114
)
115115

116+
# Look up the flash-attention K/V rank of a type name.
#
# Arguments:
#   OUT  - name of the variable, set in the caller's scope, that receives the rank
#   TYPE - K/V type name, e.g. "f16", "q8_0" or "turbo4_0"
#
# Ranks run from 0 ("f16") to 12 ("q2_0"). Several names may map to the same
# rank: turbo4_tcq and turbo4_0 share rank 7 with q4_1, turbo3_tcq and turbo3_0
# share rank 9 with q3_1, and turbo2_tcq and turbo2_0 share rank 11 with q2_1.
# A name that is not in the table aborts configuration with a fatal error.
function(ggml_cuda_fa_kv_rank_rocm OUT TYPE)
    # One entry per rank, in ascending rank order starting at 0.
    # "|" separates the names that share a single rank.
    set(kv_rank_groups
        "f16"
        "bf16"
        "q8_0"
        "q6_1"
        "q6_0"
        "q5_1"
        "q5_0"
        "q4_1|turbo4_tcq|turbo4_0"
        "q4_0"
        "q3_1|turbo3_tcq|turbo3_0"
        "q3_0"
        "q2_1|turbo2_tcq|turbo2_0"
        "q2_0"
    )

    set(kv_rank 0)
    foreach(kv_group IN LISTS kv_rank_groups)
        # Turn "a|b|c" into a real CMake list so it can be searched by exact name.
        string(REPLACE "|" ";" kv_names "${kv_group}")
        list(FIND kv_names "${TYPE}" kv_pos)
        if (NOT kv_pos EQUAL -1)
            set(${OUT} ${kv_rank} PARENT_SCOPE)
            return()
        endif()
        math(EXPR kv_rank "${kv_rank} + 1")
    endforeach()

    message(FATAL_ERROR "Unknown HIP/ROCm FA K/V rank type: ${TYPE}")
endfunction()
147+
116148
list(LENGTH GGML_CUDA_FA_KV_TYPES_ORDERED GGML_CUDA_FA_KV_TYPES_LEN)
117149
math(EXPR GGML_CUDA_FA_KV_TYPES_LAST "${GGML_CUDA_FA_KV_TYPES_LEN} - 1")
118150

@@ -124,15 +156,18 @@ list(APPEND GGML_SOURCES_ROCM ${SRCS})
124156
foreach(VI RANGE 0 ${GGML_CUDA_FA_KV_TYPES_LAST})
125157
list(GET GGML_CUDA_FA_KV_TYPES_ORDERED ${VI} V_TYPE)
126158

127-
if (KI LESS_EQUAL VI OR K_TYPE STREQUAL "f16" OR V_TYPE STREQUAL "f16")
159+
ggml_cuda_fa_kv_rank_rocm(K_RANK "${K_TYPE}")
160+
ggml_cuda_fa_kv_rank_rocm(V_RANK "${V_TYPE}")
161+
162+
if (K_RANK LESS_EQUAL V_RANK OR K_TYPE STREQUAL "f16" OR V_TYPE STREQUAL "f16")
128163
ggml_add_fattn_vec_pair_rocm(GGML_SOURCES_ROCM "${FA_VEC_PREFIX}" "${K_TYPE}" "${V_TYPE}")
129164
math(EXPR GGML_CUDA_FA_HALF_PAIR_COUNT "${GGML_CUDA_FA_HALF_PAIR_COUNT} + 1")
130165
endif()
131166
endforeach()
132167
endforeach()
133168

134-
if (NOT GGML_CUDA_FA_HALF_PAIR_COUNT EQUAL 208)
135-
message(FATAL_ERROR "GGML_CUDA_FA_HALF_QUANTS expected 208 FA vec pairs, got ${GGML_CUDA_FA_HALF_PAIR_COUNT}")
169+
if (NOT GGML_CUDA_FA_HALF_PAIR_COUNT EQUAL 217)
170+
message(FATAL_ERROR "GGML_CUDA_FA_HALF_QUANTS expected 217 FA vec pairs, got ${GGML_CUDA_FA_HALF_PAIR_COUNT}")
136171
endif()
137172

138173
add_compile_definitions(GGML_CUDA_FA_HALF_QUANTS)

0 commit comments

Comments
 (0)