Skip to content

Commit ca36a2d

Browse files
committed
[CUDA] Add gather_qqmm
1 parent b5643ca commit ca36a2d

11 files changed

Lines changed: 399 additions & 98 deletions

File tree

mlx/backend/cpu/quantized.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,4 +1359,8 @@ void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
13591359
}
13601360
}
13611361

1362+
// CPU evaluation of GatherQQMM is not provided: this unconditionally throws
// std::runtime_error with the message "[GatherQQMM] NYI" and never reads
// `inputs` or writes `out`. (NYI presumably means "not yet implemented".)
void GatherQQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
1363+
throw std::runtime_error("[GatherQQMM] NYI");
1364+
}
1365+
13621366
} // namespace mlx::core

mlx/backend/cuda/quantized/qqmm.cpp

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ std::tuple<array, array> quantize_input(
2121
QuantizationMode mode,
2222
int bits,
2323
int group_size,
24-
std::optional<array> global_scale = std::nullopt) {
24+
std::optional<array> global_scale) {
2525
const array x = ensure_contiguous(input, encoder, s);
2626

2727
// Compute output shapes
@@ -54,7 +54,7 @@ std::tuple<array, array> quantize_input(
5454

5555
array quantize_dequantize_input(
5656
const array& x_pre,
57-
const std::optional<array>& global_scale_x,
57+
const std::optional<array>& global_scale,
5858
int bits,
5959
int group_size,
6060
cu::CommandEncoder& encoder,
@@ -69,7 +69,7 @@ array quantize_dequantize_input(
6969
if (!donate_x) {
7070
encoder.add_temporary(xhat);
7171
}
72-
fp_quantize_dequantize(x, xhat, group_size, bits, global_scale_x, encoder, s);
72+
fp_quantize_dequantize(x, xhat, group_size, bits, global_scale, encoder, s);
7373
return xhat;
7474
}
7575

@@ -99,24 +99,21 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
9999

100100
const array& x_pre = inputs[0];
101101
const array& w_pre = inputs[1];
102-
const array& scales_w_pre = inputs[2];
103102

104103
out.set_data(cu::malloc_async(out.nbytes(), encoder));
105104

106105
// - 2 inputs: x, w (non-quantized w)
107106
// - 3 inputs: x, w, scales_w (quantized w)
108107
bool w_quantized = (w_pre.dtype() == uint32);
109108
int base_size = w_quantized ? 3 : 2;
110-
assert(
111-
inputs.size() == base_size ||
112-
(mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));
113-
114109
// For nvfp4, global scales are optional but must be both present or both
115110
// absent. If present, they add 2 more inputs (global_scale_x, global_scale_w)
116111
bool has_global_scales =
117-
mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
118-
std::optional<array> global_scale_x = std::nullopt;
119-
std::optional<array> global_scale_w = std::nullopt;
112+
mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2;
113+
assert(inputs.size() == base_size || has_global_scales);
114+
115+
std::optional<array> global_scale_x;
116+
std::optional<array> global_scale_w;
120117
if (has_global_scales) {
121118
global_scale_x = inputs[inputs.size() - 2];
122119
global_scale_w = inputs[inputs.size() - 1];
@@ -128,12 +125,14 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
128125
w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
129126
: std::make_tuple(
130127
ensure_contiguous(w_pre, encoder, s),
131-
ensure_contiguous(scales_w_pre, encoder, s));
128+
ensure_contiguous(inputs[2], encoder, s));
132129

133130
// Reroute to qmm when: no support in cuBLAS, or doing GEMV.
131+
bool can_use_cublas =
132+
(mode_ == QuantizationMode::Nvfp4 || mode_ == QuantizationMode::Mxfp8) &&
133+
(device.compute_capability_major() >= 10);
134134
int M = x_pre.shape(-2);
135-
bool use_qmm = (device.compute_capability_major() < 10) || (M == 1);
136-
use_qmm = true;
135+
bool use_qmm = (!can_use_cublas) || (M == 1);
137136

138137
if (use_qmm) {
139138
array x = quantize_dequantize_input(
@@ -207,4 +206,63 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
207206
scalars);
208207
}
209208

209+
void GatherQQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
210+
nvtx3::scoped_range r("QQMatmul::eval_gpu");
211+
212+
auto& s = stream();
213+
auto& encoder = cu::get_command_encoder(s);
214+
215+
const array& x_pre = inputs[0];
216+
const array& w_pre = inputs[1];
217+
const array& lhs_indices = ensure_row_contiguous(inputs[2], encoder, s);
218+
const array& rhs_indices = ensure_row_contiguous(inputs[3], encoder, s);
219+
220+
out.set_data(cu::malloc_async(out.nbytes(), encoder));
221+
222+
// - 4 inputs: x, w, lhs_indices, rhs_indices (non-quantized w)
223+
// - 5 inputs: x, w, lhs_indices, rhs_indices, scales_w (quantized w)
224+
bool w_quantized = (w_pre.dtype() == uint32);
225+
int base_size = w_quantized ? 5 : 4;
226+
// For nvfp4, global scales are optional but must be both present or both
227+
// absent If present, they add 2 more inputs (global_scale_x, global_scale_w)
228+
bool has_global_scales =
229+
mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2;
230+
assert(inputs.size() == base_size || has_global_scales);
231+
232+
std::optional<array> global_scale_x;
233+
std::optional<array> global_scale_w;
234+
if (has_global_scales) {
235+
global_scale_x = inputs[inputs.size() - 2];
236+
global_scale_w = inputs[inputs.size() - 1];
237+
}
238+
239+
// Quantize weights.
240+
auto [w_q, scales_w] = !w_quantized
241+
? quantize_input(
242+
w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
243+
: std::make_tuple(
244+
ensure_contiguous(w_pre, encoder, s),
245+
ensure_contiguous(inputs[4], encoder, s));
246+
247+
// Quantize activation.
248+
array x = quantize_dequantize_input(
249+
x_pre, global_scale_x, bits_, group_size_, encoder, s);
250+
251+
// Reroute to qmm.
252+
qmm_naive(
253+
x,
254+
w_q,
255+
scales_w,
256+
std::nullopt,
257+
global_scale_w,
258+
lhs_indices,
259+
rhs_indices,
260+
out,
261+
true, // transpose
262+
bits_,
263+
group_size_,
264+
mode_,
265+
encoder);
266+
}
267+
210268
} // namespace mlx::core

mlx/backend/metal/quantized.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,10 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
16671667
}
16681668
}
16691669

1670+
// GPU evaluation of GatherQQMM is not provided in this file (listed on the
// page as mlx/backend/metal/quantized.cpp): this unconditionally throws
// std::runtime_error with the message "[GatherQQMM] NYI" and never reads
// `inputs` or writes `out`. (NYI presumably means "not yet implemented".)
void GatherQQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
1671+
throw std::runtime_error("[GatherQQMM] NYI");
1672+
}
1673+
16701674
void fast::Quantize::eval_gpu(
16711675
const std::vector<array>& inputs,
16721676
std::vector<array>& outputs) {

mlx/backend/no_cpu/primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ NO_CPU(Gather)
7171
NO_CPU(GatherAxis)
7272
NO_CPU(GatherMM)
7373
NO_CPU(GatherQMM)
74+
NO_CPU(GatherQQMM)
7475
NO_CPU(Greater)
7576
NO_CPU(GreaterEqual)
7677
NO_CPU(Hadamard)

mlx/backend/no_gpu/primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ NO_GPU(Gather)
9898
NO_GPU(GatherAxis)
9999
NO_GPU(GatherMM)
100100
NO_GPU(GatherQMM)
101+
NO_GPU(GatherQQMM)
101102
NO_GPU(Greater)
102103
NO_GPU(GreaterEqual)
103104
NO_GPU(Hadamard)

0 commit comments

Comments
 (0)