Skip to content

Commit 0ff1115

Browse files
Lyxot and zcbenz authored
[CUDA] Implement BlockMaskedMM (ml-explore#3299)
Co-authored-by: Cheng <git@zcbenz.com>
1 parent df7f7db commit 0ff1115

7 files changed

Lines changed: 475 additions & 3 deletions

File tree

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
import argparse
4+
import time
5+
6+
import mlx.core as mx
7+
import numpy as np
8+
9+
# Maps a dtype name given as a string to the corresponding MLX dtype object.
MLX_DTYPES = {
    "float16": mx.float16,
    "bfloat16": mx.bfloat16,
    "float32": mx.float32,
}
14+
15+
16+
def parse_cases(cases):
    """Parse a comma-separated list of ``MxNxKxBS[xSparsity]`` benchmark specs.

    Args:
        cases: String such as ``"256x256x256x32x0.5,512x512x512x32"``.
            Whitespace around entries is ignored, and so are empty entries
            (for example the one produced by a trailing comma).

    Returns:
        A list of ``(m, n, k, bs, sparsity)`` tuples. The first four items are
        ints; ``sparsity`` is a float that defaults to 0.5 when the spec does
        not provide a fifth field.

    Raises:
        ValueError: If an entry has fewer than four ``x``-separated fields, or
            if a field cannot be converted to a number.
    """
    parsed = []
    for spec in cases.split(","):
        spec = spec.strip()
        if not spec:
            # Tolerate trailing or doubled commas instead of failing on int("").
            continue
        parts = spec.split("x")
        if len(parts) < 4:
            raise ValueError(
                f"Invalid case {spec!r}: expected MxNxKxBS[xSparsity]"
            )
        m, n, k, bs = (int(field) for field in parts[:4])
        sparsity = float(parts[4]) if len(parts) > 4 else 0.5
        parsed.append((m, n, k, bs, sparsity))
    return parsed
24+
25+
26+
def make_masks(m, n, k, block_size, sparsity, rng):
    """Draw random boolean block masks for the LHS, RHS and output matrices.

    Each dimension is divided into ``ceil(dim / block_size)`` blocks. A block
    is kept (True) when its uniform sample from ``rng`` is at least
    ``sparsity``, so ``sparsity`` is the expected fraction of zeroed blocks.

    Returns:
        ``(lhs_mask, rhs_mask, out_mask)`` with block-grid shapes
        ``(tm, tk)``, ``(tk, tn)`` and ``(tm, tn)`` respectively.
    """

    def block_count(extent):
        # Ceiling division: a partial trailing block still counts as a block.
        return (extent + block_size - 1) // block_size

    blocks_m, blocks_n, blocks_k = (block_count(dim) for dim in (m, n, k))

    # The draw order (lhs, rhs, out) fixes which masks a given seed produces.
    grid_shapes = (
        (blocks_m, blocks_k),
        (blocks_k, blocks_n),
        (blocks_m, blocks_n),
    )
    lhs_mask, rhs_mask, out_mask = (
        (rng.random(shape) >= sparsity).astype(np.bool_) for shape in grid_shapes
    )
    return lhs_mask, rhs_mask, out_mask
36+
37+
38+
def mlx_naive_block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask):
    """Reference path: expand block masks to element masks, then matmul.

    Each block mask is repeated ``block_size`` times along its two trailing
    axes and trimmed to the operand's size; the operands are masked before
    the product and the product is masked afterwards.
    """
    m_dim = a.shape[-2]
    k_dim = a.shape[-1]
    n_dim = b.shape[-1]

    def to_element_mask(block_mask, rows, cols):
        # Stretch rows first, then columns, and cut off the overhang left by
        # a partial last block.
        stretched_rows = mx.repeat(block_mask, block_size, axis=-2)
        stretched = mx.repeat(stretched_rows, block_size, axis=-1)
        return stretched[..., :rows, :cols]

    lhs = a * to_element_mask(lhs_mask, m_dim, k_dim)
    rhs = b * to_element_mask(rhs_mask, k_dim, n_dim)
    product = lhs @ rhs
    return product * to_element_mask(out_mask, m_dim, n_dim)
52+
53+
54+
def bench_mlx(fn, warmup, iters):
    """Return the mean wall-clock time of one ``fn`` call in milliseconds.

    ``fn`` is first called ``warmup`` times untimed, then ``iters`` times
    under a ``time.perf_counter`` timer. Every call is followed by
    ``mx.eval`` on its result and ``mx.synchronize()``, so both are included
    in the measured time.
    """
    for _ in range(warmup):
        result = fn()
        mx.eval(result)
        mx.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        result = fn()
        mx.eval(result)
        mx.synchronize()
    elapsed_s = time.perf_counter() - t0

    # Seconds -> milliseconds, averaged over the timed iterations.
    return elapsed_s * 1e3 / iters
66+
67+
68+
def print_table(headers, rows):
    """Print ``headers`` and ``rows`` as a pipe-delimited, left-aligned table.

    All cells must be strings. Each column is padded to the width of its
    longest cell (header included), and a dashed separator line is printed
    between the header and the data rows.
    """
    col_widths = list(map(len, headers))
    for record in rows:
        for col, text in enumerate(record):
            if len(text) > col_widths[col]:
                col_widths[col] = len(text)

    def render(cells):
        padded = [text.ljust(col_widths[col]) for col, text in enumerate(cells)]
        return "| " + " | ".join(padded) + " |"

    divider = "|-" + "-|-".join("-" * width for width in col_widths) + "-|"

    print(render(headers))
    print(divider)
    for record in rows:
        print(render(record))
86+
87+
88+
def main():
    """Benchmark ``mx.block_masked_mm`` against the naive expand+matmul path.

    Parses the command line, then for every case builds seeded random
    operands and block masks, optionally measures the maximum absolute
    difference between the two implementations, times both, and finally
    prints one results table for all cases.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark block_masked_mm vs naive expand+matmul"
    )
    parser.add_argument(
        "--cases",
        default=(
            "256x256x256x32x0.5,"
            "512x512x512x32x0.5,"
            "1024x1024x1024x32x0.5,"
            "1024x1024x1024x64x0.5,"
            "2048x2048x2048x64x0.5,"
            "256x256x256x32x0.0,"
            "1024x1024x1024x32x0.0,"
            "1024x1024x1024x32x0.9"
        ),
        help="Comma-separated MxNxKxBSxSparsity list. Sparsity=fraction of blocks zeroed.",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "bfloat16", "float32"],
    )
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-check", action="store_true")
    args = parser.parse_args()

    run_check = not args.no_check
    operand_dtype = MLX_DTYPES[args.dtype]

    print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}")

    headers = ["Case (MxNxKxBS)", "Sparsity", "MLX ms", "Naive ms", "Speedup"]
    if run_check:
        headers.append("Max err")

    table = []
    for case_idx, (m, n, k, bs, sparsity) in enumerate(parse_cases(args.cases)):
        # Each case gets its own generator, seeded from the base seed and the
        # case's position in the list.
        rng = np.random.default_rng(args.seed + case_idx)
        a_np = rng.standard_normal((m, k)).astype(np.float32)
        b_np = rng.standard_normal((k, n)).astype(np.float32)
        lhs_np, rhs_np, out_np = make_masks(m, n, k, bs, sparsity, rng)

        a_mx = mx.array(a_np, dtype=operand_dtype)
        b_mx = mx.array(b_np, dtype=operand_dtype)
        lhs_mx = mx.array(lhs_np)
        rhs_mx = mx.array(rhs_np)
        out_mx = mx.array(out_np)
        mx.eval(a_mx, b_mx, lhs_mx, rhs_mx, out_mx)

        # Both implementations take the same positional arguments.
        call_args = (a_mx, b_mx, bs, out_mx, lhs_mx, rhs_mx)

        # Correctness check: block_masked_mm vs naive expand+matmul
        err_text = ""
        if run_check:
            y_op = mx.block_masked_mm(*call_args)
            y_naive = mlx_naive_block_masked_mm(*call_args)
            mx.eval(y_op, y_naive)
            max_abs_err = float(mx.max(mx.abs(y_op - y_naive)).item())
            err_text = f"{max_abs_err:.2e}"

        # Benchmark
        t_mlx = bench_mlx(
            lambda: mx.block_masked_mm(*call_args),
            args.warmup,
            args.iters,
        )
        t_naive = bench_mlx(
            lambda: mlx_naive_block_masked_mm(*call_args),
            args.warmup,
            args.iters,
        )
        if t_mlx > 0:
            speedup = f"{t_naive / t_mlx:.2f}x"
        else:
            speedup = "-"

        record = [
            f"{m}x{n}x{k}x{bs}",
            f"{sparsity:.0%}",
            f"{t_mlx:.3f}",
            f"{t_naive:.3f}",
            speedup,
        ]
        if run_check:
            record.append(err_text)
        table.append(record)

    print_table(headers, table)
    if run_check:
        print("err: max|block_masked_mm - naive_expand_matmul|")
190+
191+
192+
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ target_sources(
2828
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
2929
${CMAKE_CURRENT_SOURCE_DIR}/fft.cu
3030
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
3132
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
3233
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include "mlx/backend/cuda/device.h"
4+
#include "mlx/backend/cuda/device/utils.cuh"
5+
#include "mlx/backend/cuda/gemms/block_mask.h"
6+
#include "mlx/backend/cuda/kernel_utils.cuh"
7+
#include "mlx/dtype_utils.h"
8+
9+
#include <cooperative_groups.h>
10+
11+
namespace mlx::core {
12+
13+
namespace cg = cooperative_groups;
14+
15+
namespace cu {
16+
17+
// Element-wise kernel that writes `src` into `dst` with a block mask applied.
//
// Each thread handles one output element: `idx` is the thread's rank in the
// grid, and threads with `idx >= batch_count * rows * cols` exit immediately.
// `dst` is always indexed linearly by `idx`. `src` is indexed linearly when
// SrcContiguous is true and through `src_shape` / `src_strides` otherwise.
//
// The mask entry used for element (row, col) of a matrix is the one at
// (row / block_size, col / block_size), addressed with `mask_row_stride` and
// `mask_col_stride` on top of a per-batch offset resolved from `mask_shape` /
// `mask_strides`. A bool mask selects between the source value and zero; any
// other mask type is converted to T and multiplied into the source value.
template <typename T, typename MaskT, bool SrcContiguous>
__global__ void block_mask_copy_kernel(
    const T* src,
    T* dst,
    int block_size,
    int64_t rows,
    int64_t cols,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int src_ndim,
    MaskT* mask,
    const __grid_constant__ Shape mask_shape,
    const __grid_constant__ Strides mask_strides,
    int mask_ndim,
    int64_t mask_row_stride,
    int64_t mask_col_stride,
    int64_t mask_mat_size,
    int64_t batch_count) {
  int64_t elems_per_mat = rows * cols;
  int64_t idx = cg::this_grid().thread_rank();
  if (idx >= batch_count * elems_per_mat) {
    return;
  }

  // Split the linear index into (batch, row, col).
  int64_t batch = idx / elems_per_mat;
  int64_t within = idx % elems_per_mat;
  int64_t row = within / cols;
  int64_t col = within % cols;

  // Start of this batch's mask matrix, then step to the covering block.
  int64_t mask_loc = elem_to_loc(
      batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim);
  mask_loc += row / block_size * mask_row_stride;
  mask_loc += col / block_size * mask_col_stride;
  MaskT mask_val = mask[mask_loc];

  int64_t src_loc = idx;
  if constexpr (!SrcContiguous) {
    src_loc = elem_to_loc(
        batch * elems_per_mat + within,
        src_shape.data(),
        src_strides.data(),
        src_ndim);
  }

  if constexpr (std::is_same_v<MaskT, bool>) {
    // Select rather than multiply: a masked-out element becomes exactly zero
    // and its source value is never read.
    dst[idx] = mask_val ? src[src_loc] : T(0);
  } else {
    dst[idx] = src[src_loc] * T(mask_val);
  }
}
65+
66+
} // namespace cu
67+
68+
namespace {
69+
70+
// Invokes `f` with the mask element type as its explicit template argument:
// `bool` when `mask_dtype` is `bool_`, otherwise the data type `T`.
// NOTE(review): a non-bool mask is treated as T without checking that
// `mask_dtype` corresponds to T; presumably callers guarantee this, so confirm.
template <typename T, typename F>
void dispatch_mask_type(Dtype mask_dtype, F&& f) {
  const bool is_bool_mask = (mask_dtype == bool_);
  if (is_bool_mask) {
    f.template operator()<bool>();
    return;
  }
  f.template operator()<T>();
}
78+
79+
// Enqueues block_mask_copy_kernel on `encoder`, writing `src` into `dst` with
// `mask` applied.
//
// The kernel is specialized on the floating-point dtype of `src`, on the mask
// element type (bool or the same type as `src`), and on `src_contiguous`.
// The last two axes of `mask` supply the row/column strides and the per-matrix
// element count; `mask` is assumed to have at least two dimensions.
// Launch geometry comes from get_launch_args(src, ...), whose second argument
// flags sizes above INT32_MAX.
void block_mask_copy(
    cu::CommandEncoder& encoder,
    const array& src,
    array& dst,
    const array& mask,
    int block_size,
    int64_t rows,
    int64_t cols,
    bool src_contiguous,
    int64_t batch_count) {
  int mask_ndim = mask.ndim();
  int64_t mask_row_stride = mask.strides()[mask_ndim - 2];
  int64_t mask_col_stride = mask.strides()[mask_ndim - 1];
  // Widen before multiplying so the product is computed in 64 bits.
  int64_t mask_rows = mask.shape()[mask_ndim - 2];
  int64_t mask_mat_size = mask_rows * mask.shape()[mask_ndim - 1];

  auto [num_blocks, block_dims] = get_launch_args(src, src.size() > INT32_MAX);

  dispatch_float_types(src.dtype(), "block_mask_copy", [&](auto dtype_tag) {
    using T = cuda_type_t<MLX_GET_TYPE(dtype_tag)>;

    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
      dispatch_bool(src_contiguous, [&](auto contiguous_flag) {
        constexpr bool kSrcContiguous = decltype(contiguous_flag)::value;
        encoder.add_kernel_node(
            cu::block_mask_copy_kernel<T, MaskT, kSrcContiguous>,
            num_blocks,
            block_dims,
            gpu_ptr<T>(src),
            gpu_ptr<T>(dst),
            block_size,
            rows,
            cols,
            const_param(src.shape()),
            const_param(src.strides()),
            src.ndim(),
            gpu_ptr<MaskT>(mask),
            const_param(mask.shape()),
            const_param(mask.strides()),
            mask_ndim,
            mask_row_stride,
            mask_col_stride,
            mask_mat_size,
            batch_count);
      });
    });
  });
}
127+
128+
} // namespace
129+
130+
// Applies `mask` to `data` in place.
//
// Registers `mask` as an input and `data` as an output on `encoder`, then
// enqueues the masking kernel with `data` as both source and destination.
void apply_block_mask(
    cu::CommandEncoder& encoder,
    array& data,
    const array& mask,
    int block_size,
    int64_t rows,
    int64_t cols,
    int64_t batch_count) {
  encoder.set_input_array(mask);
  encoder.set_output_array(data);

  // In-place use (src == dst): the contiguous path reads and writes the same
  // linear index, so the strided source lookup is not needed.
  constexpr bool kSrcContiguous = true;
  block_mask_copy(
      encoder,
      data,
      data,
      mask,
      block_size,
      rows,
      cols,
      kSrcContiguous,
      batch_count);
}
145+
146+
// Returns a newly allocated array with the shape and dtype of `src`, holding
// `src` with `mask` applied.
//
// The result buffer comes from cu::malloc_async and is registered with
// `encoder` as a temporary. `src` and `mask` are registered as inputs and the
// result as an output before the masking kernel is enqueued.
array copy_with_block_mask(
    cu::CommandEncoder& encoder,
    const array& src,
    const array& mask,
    int block_size,
    int64_t rows,
    int64_t cols,
    int64_t batch_count) {
  array masked(src.shape(), src.dtype(), nullptr, {});
  masked.set_data(cu::malloc_async(masked.nbytes(), encoder));
  encoder.add_temporary(masked);

  encoder.set_input_array(src);
  encoder.set_input_array(mask);
  encoder.set_output_array(masked);

  // A row-contiguous source is read by linear index; any other layout is
  // read through its shape and strides.
  const bool src_is_row_contiguous = src.flags().row_contiguous;
  block_mask_copy(
      encoder,
      src,
      masked,
      mask,
      block_size,
      rows,
      cols,
      src_is_row_contiguous,
      batch_count);

  return masked;
}
175+
176+
} // namespace mlx::core

0 commit comments

Comments
 (0)