@@ -99,9 +99,9 @@ list(APPEND GGML_SOURCES_ROCM ${SRCS})
9999 q6_0
100100 q5_1
101101 q5_0
102+ q4_1
102103 turbo4_tcq
103104 turbo4_0
104- q4_1
105105 q4_0
106106 q3_1
107107 turbo3_tcq
@@ -113,6 +113,38 @@ list(APPEND GGML_SOURCES_ROCM ${SRCS})
113113 q2_0
114114 )
115115
# ggml_cuda_fa_kv_rank_rocm(<OUT> <TYPE>)
#
# Look up the integer rank of a flash-attention K/V type name and store it in
# the variable named by <OUT> in the caller's scope.
#
# Ranks run from 0 (f16) to 12 (q2_0). Several names share one rank:
#   7  -> q4_1, turbo4_tcq, turbo4_0
#   9  -> q3_1, turbo3_tcq, turbo3_0
#   11 -> q2_1, turbo2_tcq, turbo2_0
# NOTE(review): the ranks look like a widest-to-narrowest ordering of the
# storage formats - confirm against the kernel sources.
#
# Arguments:
#   OUT   name of the result variable to set in the parent scope
#   TYPE  type name such as "f16", "q8_0" or "turbo4_tcq"; leading/trailing
#         whitespace is ignored
#
# An unrecognised name aborts configuration with FATAL_ERROR.
#
# Example:
#   ggml_cuda_fa_kv_rank_rocm(K_RANK "${K_TYPE}")
function(ggml_cuda_fa_kv_rank_rocm OUT TYPE)
    # Callers may hand over a padded name (e.g. "q4_0 "); compare on the
    # trimmed text so padding alone never triggers the "unknown type" error.
    string(STRIP "${TYPE}" kv_type)

    if(kv_type STREQUAL "f16")
        set(kv_rank 0)
    elseif(kv_type STREQUAL "bf16")
        set(kv_rank 1)
    elseif(kv_type STREQUAL "q8_0")
        set(kv_rank 2)
    elseif(kv_type STREQUAL "q6_1")
        set(kv_rank 3)
    elseif(kv_type STREQUAL "q6_0")
        set(kv_rank 4)
    elseif(kv_type STREQUAL "q5_1")
        set(kv_rank 5)
    elseif(kv_type STREQUAL "q5_0")
        set(kv_rank 6)
    elseif(kv_type STREQUAL "q4_1" OR kv_type STREQUAL "turbo4_tcq" OR kv_type STREQUAL "turbo4_0")
        set(kv_rank 7)
    elseif(kv_type STREQUAL "q4_0")
        set(kv_rank 8)
    elseif(kv_type STREQUAL "q3_1" OR kv_type STREQUAL "turbo3_tcq" OR kv_type STREQUAL "turbo3_0")
        set(kv_rank 9)
    elseif(kv_type STREQUAL "q3_0")
        set(kv_rank 10)
    elseif(kv_type STREQUAL "q2_1" OR kv_type STREQUAL "turbo2_tcq" OR kv_type STREQUAL "turbo2_0")
        set(kv_rank 11)
    elseif(kv_type STREQUAL "q2_0")
        set(kv_rank 12)
    else()
        message(FATAL_ERROR "Unknown HIP/ROCm FA K/V rank type: ${TYPE}")
    endif()

    # Single exit point: publish the rank to the caller.
    set(${OUT} ${kv_rank} PARENT_SCOPE)
endfunction()
147+
# Count the entries of the ordered K/V type list and derive the index of the
# last one; that index is the inclusive upper bound of the
# foreach(... RANGE 0 <last>) iteration over the list further down.
116148 list (LENGTH GGML_CUDA_FA_KV_TYPES_ORDERED GGML_CUDA_FA_KV_TYPES_LEN)
117149 math (EXPR GGML_CUDA_FA_KV_TYPES_LAST "${GGML_CUDA_FA_KV_TYPES_LEN} - 1" )
118150
@@ -124,15 +156,18 @@ list(APPEND GGML_SOURCES_ROCM ${SRCS})
# Walk every V type and register a vec-kernel source for each accepted
# (K_TYPE, V_TYPE) combination.
# NOTE(review): K_TYPE is presumably set by an enclosing loop over the same
# list that is not visible in this excerpt - confirm against the full file.
foreach(VI RANGE 0 ${GGML_CUDA_FA_KV_TYPES_LAST})
    list(GET GGML_CUDA_FA_KV_TYPES_ORDERED ${VI} V_TYPE)

    # Compare by rank rather than by list position, so that types sharing a
    # rank are accepted in either K/V order. The names are passed without any
    # padding inside the quotes; a padded name would not match the helper's
    # exact string comparisons.
    ggml_cuda_fa_kv_rank_rocm(K_RANK "${K_TYPE}")
    ggml_cuda_fa_kv_rank_rocm(V_RANK "${V_TYPE}")

    # Accept the pair when K ranks no higher than V, or when either side is f16.
    if(K_RANK LESS_EQUAL V_RANK OR K_TYPE STREQUAL "f16" OR V_TYPE STREQUAL "f16")
        ggml_add_fattn_vec_pair_rocm(GGML_SOURCES_ROCM "${FA_VEC_PREFIX}" "${K_TYPE}" "${V_TYPE}")
        # Running total of registered pairs; checked after the loops.
        math(EXPR GGML_CUDA_FA_HALF_PAIR_COUNT "${GGML_CUDA_FA_HALF_PAIR_COUNT} + 1")
    endif()
endforeach()
132167 endforeach ()
133168
# Guard against the type list and the pairing rule drifting apart: the loops
# above must have registered exactly 217 (K, V) pairs.
# NOTE(review): 217 = 199 pairs with rank(K) <= rank(V) plus 18 extra pairs
# with V = f16, assuming the ordered list holds the 19 type names known to
# ggml_cuda_fa_kv_rank_rocm - confirm whenever a type is added or removed.
if(NOT GGML_CUDA_FA_HALF_PAIR_COUNT EQUAL 217)
    message(FATAL_ERROR "GGML_CUDA_FA_HALF_QUANTS expected 217 FA vec pairs, got ${GGML_CUDA_FA_HALF_PAIR_COUNT}")
endif()

# Define GGML_CUDA_FA_HALF_QUANTS for the sources compiled from this directory.
add_compile_definitions(GGML_CUDA_FA_HALF_QUANTS)
0 commit comments