@@ -232,8 +232,8 @@ void fp_qmv(
232232 using T = cuda_type_t <MLX_GET_TYPE (type_tag)>;
233233 if constexpr (!std::is_same_v<T, double >) {
234234 dim3 block_dims{WARP_SIZE, rows_per_block};
235- uint B = out.size () / (M * N);
236- uint blocks_y = (N + rows_per_block - 1 ) / rows_per_block;
235+ uint32_t B = out.size () / (M * N);
236+ uint32_t blocks_y = (N + rows_per_block - 1 ) / rows_per_block;
237237 const uint32_t * mat_ptr = gpu_ptr<uint32_t >(mat);
238238 const T* vec_ptr = gpu_ptr<T>(vec);
239239 int n = 1 ;
@@ -249,16 +249,17 @@ void fp_qmv(
249249 }
250250 dispatch_1_2_4 (n, [&](auto n) {
251251 dispatch_bool (B > 1 , [&](auto batched) {
252- if (!batched ()) {
253- auto kernel = fp_qmv_single<T, rows_per_block, n (), 4 , 32 , true >;
252+ if (!batched.value ) {
253+ auto kernel =
254+ fp_qmv_single<T, rows_per_block, n.value , 4 , 32 , true >;
254255 if (bits == 8 ) {
255- kernel = fp_qmv_single<T, rows_per_block, n () , 8 , 32 , true >;
256+ kernel = fp_qmv_single<T, rows_per_block, n. value , 8 , 32 , true >;
256257 } else if (group_size == 16 ) {
257- kernel = fp_qmv_single<T, rows_per_block, n () , 4 , 16 , false >;
258+ kernel = fp_qmv_single<T, rows_per_block, n. value , 4 , 16 , false >;
258259 }
259260 encoder.add_kernel_node (
260261 kernel,
261- {static_cast <uint >(M), blocks_y},
262+ {static_cast <uint32_t >(M), blocks_y},
262263 block_dims,
263264 0 ,
264265 mat_ptr,
@@ -268,15 +269,16 @@ void fp_qmv(
268269 N,
269270 K);
270271 } else {
271- auto kernel = fp_qmv_batched<T, rows_per_block, n (), 4 , 32 , true >;
272+ auto kernel =
273+ fp_qmv_batched<T, rows_per_block, n.value , 4 , 32 , true >;
272274 if (bits == 8 ) {
273- kernel = fp_qmv_batched<T, rows_per_block, n () , 8 , 32 , true >;
275+ kernel = fp_qmv_batched<T, rows_per_block, n. value , 8 , 32 , true >;
274276 } else if (group_size == 16 ) {
275- kernel = fp_qmv_batched<T, rows_per_block, n () , 4 , 16 , false >;
277+ kernel = fp_qmv_batched<T, rows_per_block, n. value , 4 , 16 , false >;
276278 }
277279 encoder.add_kernel_node (
278280 kernel,
279- {static_cast <uint >(M), blocks_y, B},
281+ {static_cast <uint32_t >(M), blocks_y, B},
280282 block_dims,
281283 0 ,
282284 mat_ptr,
0 commit comments