Skip to content

Commit b396742

Browse files
committed
[Metal] Add gather_qqmm
1 parent 596b8b3 commit b396742

9 files changed

Lines changed: 694 additions & 365 deletions

File tree

mlx/backend/cuda/quantized/qqmm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
125125
w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
126126
: std::make_tuple(
127127
ensure_contiguous(w_pre, encoder, s),
128-
ensure_contiguous(inputs[2], encoder, s));
128+
ensure_contiguous(inputs[base_size - 1], encoder, s));
129129

130130
// Reroute to qmm when: no support in cuBLAS, or doing GEMV.
131131
bool can_use_cublas =
@@ -242,7 +242,7 @@ void GatherQQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
242242
w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
243243
: std::make_tuple(
244244
ensure_contiguous(w_pre, encoder, s),
245-
ensure_contiguous(inputs[4], encoder, s));
245+
ensure_contiguous(inputs[base_size - 1], encoder, s));
246246

247247
// Quantize activation.
248248
array x = quantize_dequantize_input(

mlx/backend/metal/kernels/fp4.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3+
constant constexpr float F8E4M3_MAX = 448.0f;
4+
constant constexpr float F4E2M1_MAX = 6.0f;
5+
36
struct fp4_e2m1 {
47
fp4_e2m1(float x) {
58
if (metal::isnan(x)) {

0 commit comments

Comments
 (0)