@@ -21,7 +21,7 @@ std::tuple<array, array> quantize_input(
2121 QuantizationMode mode,
2222 int bits,
2323 int group_size,
24- std::optional<array> global_scale = std:: nullopt ) {
24+ std::optional<array> global_scale) {
2525 const array x = ensure_contiguous (input, encoder, s);
2626
2727 // Compute output shapes
@@ -54,7 +54,7 @@ std::tuple<array, array> quantize_input(
5454
5555array quantize_dequantize_input (
5656 const array& x_pre,
57- const std::optional<array>& global_scale_x ,
57+ const std::optional<array>& global_scale ,
5858 int bits,
5959 int group_size,
6060 cu::CommandEncoder& encoder,
@@ -69,7 +69,7 @@ array quantize_dequantize_input(
6969 if (!donate_x) {
7070 encoder.add_temporary (xhat);
7171 }
72- fp_quantize_dequantize (x, xhat, group_size, bits, global_scale_x , encoder, s);
72+ fp_quantize_dequantize (x, xhat, group_size, bits, global_scale , encoder, s);
7373 return xhat;
7474}
7575
@@ -99,24 +99,21 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
9999
100100 const array& x_pre = inputs[0 ];
101101 const array& w_pre = inputs[1 ];
102- const array& scales_w_pre = inputs[2 ];
103102
104103 out.set_data (cu::malloc_async (out.nbytes (), encoder));
105104
106105 // - 2 inputs: x, w (non-quantized w)
107106 // - 3 inputs: x, w, scales_w (quantized w)
108107 bool w_quantized = (w_pre.dtype () == uint32);
109108 int base_size = w_quantized ? 3 : 2 ;
110- assert (
111- inputs.size () == base_size ||
112- (mode_ == QuantizationMode::Nvfp4 && inputs.size () == base_size + 2 ));
113-
114109 // For nvfp4, global scales are optional but must be both present or both
115110 // absent If present, they add 2 more inputs (global_scale_x, global_scale_w)
116111 bool has_global_scales =
117- mode_ == QuantizationMode::Nvfp4 && inputs.size () > base_size;
118- std::optional<array> global_scale_x = std::nullopt ;
119- std::optional<array> global_scale_w = std::nullopt ;
112+ mode_ == QuantizationMode::Nvfp4 && inputs.size () == base_size + 2 ;
113+ assert (inputs.size () == base_size || has_global_scales);
114+
115+ std::optional<array> global_scale_x;
116+ std::optional<array> global_scale_w;
120117 if (has_global_scales) {
121118 global_scale_x = inputs[inputs.size () - 2 ];
122119 global_scale_w = inputs[inputs.size () - 1 ];
@@ -128,12 +125,14 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
128125 w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
129126 : std::make_tuple (
130127 ensure_contiguous (w_pre, encoder, s),
131- ensure_contiguous (scales_w_pre , encoder, s));
128+ ensure_contiguous (inputs[ 2 ] , encoder, s));
132129
133130 // Reroute to qmm when: no support in cuBLAS, or doing GEMV.
131+ bool can_use_cublas =
132+ (mode_ == QuantizationMode::Nvfp4 || mode_ == QuantizationMode::Mxfp8) &&
133+ (device.compute_capability_major () >= 10 );
134134 int M = x_pre.shape (-2 );
135- bool use_qmm = (device.compute_capability_major () < 10 ) || (M == 1 );
136- use_qmm = true ;
135+ bool use_qmm = (!can_use_cublas) || (M == 1 );
137136
138137 if (use_qmm) {
139138 array x = quantize_dequantize_input (
@@ -207,4 +206,63 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
207206 scalars);
208207}
209208
209+ void GatherQQMM::eval_gpu (const std::vector<array>& inputs, array& out) {
210+ nvtx3::scoped_range r (" QQMatmul::eval_gpu" );
211+
212+ auto & s = stream ();
213+ auto & encoder = cu::get_command_encoder (s);
214+
215+ const array& x_pre = inputs[0 ];
216+ const array& w_pre = inputs[1 ];
217+ const array& lhs_indices = ensure_row_contiguous (inputs[2 ], encoder, s);
218+ const array& rhs_indices = ensure_row_contiguous (inputs[3 ], encoder, s);
219+
220+ out.set_data (cu::malloc_async (out.nbytes (), encoder));
221+
222+ // - 4 inputs: x, w, lhs_indices, rhs_indices (non-quantized w)
223+ // - 5 inputs: x, w, lhs_indices, rhs_indices, scales_w (quantized w)
224+ bool w_quantized = (w_pre.dtype () == uint32);
225+ int base_size = w_quantized ? 5 : 4 ;
226+ // For nvfp4, global scales are optional but must be both present or both
227+ // absent If present, they add 2 more inputs (global_scale_x, global_scale_w)
228+ bool has_global_scales =
229+ mode_ == QuantizationMode::Nvfp4 && inputs.size () == base_size + 2 ;
230+ assert (inputs.size () == base_size || has_global_scales);
231+
232+ std::optional<array> global_scale_x;
233+ std::optional<array> global_scale_w;
234+ if (has_global_scales) {
235+ global_scale_x = inputs[inputs.size () - 2 ];
236+ global_scale_w = inputs[inputs.size () - 1 ];
237+ }
238+
239+ // Quantize weights.
240+ auto [w_q, scales_w] = !w_quantized
241+ ? quantize_input (
242+ w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
243+ : std::make_tuple (
244+ ensure_contiguous (w_pre, encoder, s),
245+ ensure_contiguous (inputs[4 ], encoder, s));
246+
247+ // Quantize activation.
248+ array x = quantize_dequantize_input (
249+ x_pre, global_scale_x, bits_, group_size_, encoder, s);
250+
251+ // Reroute to qmm.
252+ qmm_naive (
253+ x,
254+ w_q,
255+ scales_w,
256+ std::nullopt ,
257+ global_scale_w,
258+ lhs_indices,
259+ rhs_indices,
260+ out,
261+ true , // transpose
262+ bits_,
263+ group_size_,
264+ mode_,
265+ encoder);
266+ }
267+
210268} // namespace mlx::core
0 commit comments