Skip to content

Commit b48adae

Browse files
committed
Merge upstream main into rocm-support
2 parents c6d9a92 + 0ff1115 commit b48adae

278 files changed

Lines changed: 16488 additions & 5561 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/actions/build-cuda-release/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ runs:
2020
run: |
2121
pip install auditwheel build patchelf setuptools
2222
python setup.py clean --all
23-
MLX_BUILD_STAGE=2 python -m build -w
23+
MLX_DISABLE_SM90A_KERNELS=1 MLX_BUILD_STAGE=2 python -m build -w
2424
2525
auditwheel repair dist/mlx_cuda*.whl \
2626
--plat manylinux_2_35_${{ inputs.arch }} \

.github/actions/build-linux/action.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ inputs:
99
runs:
1010
using: "composite"
1111
steps:
12+
1213
- name: Install Python package
1314
id: python_build
1415
shell: sh
@@ -20,7 +21,7 @@ runs:
2021
run: |
2122
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
2223
# There is no GPU in arm64 runner, use a common arch.
23-
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
24+
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=80"
2425
# Can not build tests and stubs when the built executables can not run.
2526
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF -DMLX_BUILD_PYTHON_STUBS=OFF"
2627
fi

.github/actions/build-macos/action.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ runs:
1212
run: |
1313
pip install --upgrade pip
1414
pip install cmake setuptools typing_extensions
15-
pip install -e . -v
15+
pip install -e ".[dev]" -v
1616
1717
- name: Install tests dependencies
1818
shell: bash -l {0}
1919
run: |
20-
pip install numpy torch tensorflow
20+
pip install tensorflow
2121
2222
- name: Run Python tests
2323
shell: bash -l {0}

.github/actions/test-linux/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,5 @@ runs:
6565
DEVICE: gpu
6666
run: |
6767
echo "::group::CPP tests - GPU"
68-
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
68+
./build/tests/tests -sfe="*linalg_tests.cpp"
6969
echo "::endgroup::"

.github/workflows/nightly.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ jobs:
2323
build-backend: ${{ matrix.python-version == '3.10' }}
2424
arch: "x86_64"
2525
- name: Upload mlx artifacts
26-
uses: actions/upload-artifact@v6
26+
uses: actions/upload-artifact@v7
2727
with:
2828
name: linux-wheels-${{ matrix.python_version }}
2929
path: wheelhouse/mlx-*.whl
3030
retention-days: 7
3131
- name: Upload mlx-cpu artifacts
3232
if: matrix.python_version == '3.10'
33-
uses: actions/upload-artifact@v6
33+
uses: actions/upload-artifact@v7
3434
with:
3535
name: mlx-cpu
3636
path: wheelhouse/mlx_cpu-*.whl
@@ -97,7 +97,7 @@ jobs:
9797
toolkit: 'cuda-12.9'
9898
arch: 'x86_64'
9999
- name: Upload artifacts
100-
uses: actions/upload-artifact@v6
100+
uses: actions/upload-artifact@v7
101101
with:
102102
name: mlx-cuda
103103
path: wheelhouse/mlx_cuda_*.whl

.github/workflows/release.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ jobs:
6464
build-backend: ${{ matrix.python_version == '3.10' }}
6565
arch: ${{ matrix.arch }}
6666
- name: Upload MLX artifacts
67-
uses: actions/upload-artifact@v6
67+
uses: actions/upload-artifact@v7
6868
with:
6969
overwrite: true
7070
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
7171
path: wheelhouse/mlx-*.whl
7272
if-no-files-found: error
7373
- name: Upload CPU artifacts
7474
if: matrix.python_version == '3.10'
75-
uses: actions/upload-artifact@v6
75+
uses: actions/upload-artifact@v7
7676
with:
7777
overwrite: true
7878
name: mlx-cpu-${{ matrix.arch }}
@@ -116,15 +116,15 @@ jobs:
116116
macos-target: 26.0
117117
build-backend: ${{ matrix.python-version == '3.10' }}
118118
- name: Upload MLX artifacts
119-
uses: actions/upload-artifact@v6
119+
uses: actions/upload-artifact@v7
120120
with:
121121
overwrite: true
122122
name: mac-wheels-${{ matrix.python-version }}
123123
path: dist/mlx-*.whl
124124
if-no-files-found: error
125125
- name: Upload Metal artifacts
126126
if: matrix.python-version == '3.10'
127-
uses: actions/upload-artifact@v6
127+
uses: actions/upload-artifact@v7
128128
with:
129129
overwrite: true
130130
name: mlx-metal
@@ -152,7 +152,7 @@ jobs:
152152
with:
153153
arch: ${{ matrix.arch }}
154154
- name: Upload artifacts
155-
uses: actions/upload-artifact@v6
155+
uses: actions/upload-artifact@v7
156156
with:
157157
overwrite: true
158158
name: mlx-${{ matrix.toolkit }}-${{ matrix.arch }}
@@ -169,12 +169,12 @@ jobs:
169169
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
170170
url: https://pypi.org/p/mlx
171171
steps:
172-
- uses: actions/download-artifact@v7
172+
- uses: actions/download-artifact@v8
173173
with:
174174
pattern: linux-wheels-*
175175
merge-multiple: true
176176
path: dist
177-
- uses: actions/download-artifact@v7
177+
- uses: actions/download-artifact@v8
178178
with:
179179
pattern: mac-wheels-*
180180
merge-multiple: true
@@ -197,7 +197,7 @@ jobs:
197197
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
198198
url: https://pypi.org/p/mlx-cuda
199199
steps:
200-
- uses: actions/download-artifact@v7
200+
- uses: actions/download-artifact@v8
201201
with:
202202
pattern: mlx-cuda-*
203203
merge-multiple: true
@@ -220,7 +220,7 @@ jobs:
220220
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
221221
url: https://pypi.org/p/mlx-cpu
222222
steps:
223-
- uses: actions/download-artifact@v7
223+
- uses: actions/download-artifact@v8
224224
with:
225225
pattern: mlx-cpu-*
226226
merge-multiple: true
@@ -243,7 +243,7 @@ jobs:
243243
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
244244
url: https://pypi.org/p/mlx-metal
245245
steps:
246-
- uses: actions/download-artifact@v7
246+
- uses: actions/download-artifact@v8
247247
with:
248248
name: mlx-metal
249249
path: dist

CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ if(MLX_BUILD_CUDA)
157157
enable_language(CUDA)
158158
find_package(CUDAToolkit REQUIRED)
159159
find_package(CUDNN REQUIRED)
160+
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "13.1" AND CUDAToolkit_VERSION
161+
VERSION_LESS "13.2")
162+
message(FATAL_ERROR "CUDA Toolkit 13.1 is not supported.")
163+
endif()
160164
endif()
161165

162166
if(MLX_BUILD_ROCM)
@@ -369,7 +373,7 @@ else()
369373
FetchContent_Declare(
370374
fmt
371375
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
372-
GIT_TAG 10.2.1
376+
GIT_TAG 12.1.0
373377
EXCLUDE_FROM_ALL)
374378
FetchContent_MakeAvailable(fmt)
375379
endif()
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
import argparse
4+
import time
5+
6+
import mlx.core as mx
7+
import numpy as np
8+
9+
MLX_DTYPES = {
10+
"float16": mx.float16,
11+
"bfloat16": mx.bfloat16,
12+
"float32": mx.float32,
13+
}
14+
15+
16+
def parse_cases(cases):
17+
parsed = []
18+
for spec in cases.split(","):
19+
parts = spec.split("x")
20+
m, n, k, bs = int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3])
21+
sparsity = float(parts[4]) if len(parts) > 4 else 0.5
22+
parsed.append((m, n, k, bs, sparsity))
23+
return parsed
24+
25+
26+
def make_masks(m, n, k, block_size, sparsity, rng):
27+
"""Create block masks with given sparsity (fraction of blocks zeroed)."""
28+
tm = (m + block_size - 1) // block_size
29+
tn = (n + block_size - 1) // block_size
30+
tk = (k + block_size - 1) // block_size
31+
32+
lhs_mask = (rng.random((tm, tk)) >= sparsity).astype(np.bool_)
33+
rhs_mask = (rng.random((tk, tn)) >= sparsity).astype(np.bool_)
34+
out_mask = (rng.random((tm, tn)) >= sparsity).astype(np.bool_)
35+
return lhs_mask, rhs_mask, out_mask
36+
37+
38+
def mlx_naive_block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask):
39+
"""MLX naive: expand masks and use regular matmul."""
40+
M, K = a.shape[-2], a.shape[-1]
41+
N = b.shape[-1]
42+
43+
def expand(mask, rows, cols):
44+
e = mx.repeat(mx.repeat(mask, block_size, axis=-2), block_size, axis=-1)
45+
return e[..., :rows, :cols]
46+
47+
a_masked = a * expand(lhs_mask, M, K)
48+
b_masked = b * expand(rhs_mask, K, N)
49+
c = a_masked @ b_masked
50+
c = c * expand(out_mask, M, N)
51+
return c
52+
53+
54+
def bench_mlx(fn, warmup, iters):
55+
for _ in range(warmup):
56+
y = fn()
57+
mx.eval(y)
58+
mx.synchronize()
59+
60+
start = time.perf_counter()
61+
for _ in range(iters):
62+
y = fn()
63+
mx.eval(y)
64+
mx.synchronize()
65+
return (time.perf_counter() - start) * 1e3 / iters
66+
67+
68+
def print_table(headers, rows):
69+
widths = [len(h) for h in headers]
70+
for row in rows:
71+
for i, cell in enumerate(row):
72+
widths[i] = max(widths[i], len(cell))
73+
74+
def fmt_row(row):
75+
return (
76+
"| "
77+
+ " | ".join(f"{cell:<{widths[i]}}" for i, cell in enumerate(row))
78+
+ " |"
79+
)
80+
81+
sep = "|-" + "-|-".join("-" * w for w in widths) + "-|"
82+
print(fmt_row(headers))
83+
print(sep)
84+
for row in rows:
85+
print(fmt_row(row))
86+
87+
88+
def main():
89+
parser = argparse.ArgumentParser(
90+
description="Benchmark block_masked_mm vs naive expand+matmul"
91+
)
92+
parser.add_argument(
93+
"--cases",
94+
default=(
95+
"256x256x256x32x0.5,"
96+
"512x512x512x32x0.5,"
97+
"1024x1024x1024x32x0.5,"
98+
"1024x1024x1024x64x0.5,"
99+
"2048x2048x2048x64x0.5,"
100+
"256x256x256x32x0.0,"
101+
"1024x1024x1024x32x0.0,"
102+
"1024x1024x1024x32x0.9"
103+
),
104+
help="Comma-separated MxNxKxBSxSparsity list. Sparsity=fraction of blocks zeroed.",
105+
)
106+
parser.add_argument(
107+
"--dtype",
108+
default="float32",
109+
choices=["float16", "bfloat16", "float32"],
110+
)
111+
parser.add_argument("--warmup", type=int, default=10)
112+
parser.add_argument("--iters", type=int, default=50)
113+
parser.add_argument("--seed", type=int, default=42)
114+
parser.add_argument("--no-check", action="store_true")
115+
args = parser.parse_args()
116+
117+
mlx_dtype = MLX_DTYPES[args.dtype]
118+
119+
print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}")
120+
121+
headers = [
122+
"Case (MxNxKxBS)",
123+
"Sparsity",
124+
"MLX ms",
125+
"Naive ms",
126+
"Speedup",
127+
]
128+
if not args.no_check:
129+
headers.append("Max err")
130+
rows = []
131+
132+
cases = parse_cases(args.cases)
133+
for idx, (m, n, k, bs, sparsity) in enumerate(cases):
134+
rng = np.random.default_rng(args.seed + idx)
135+
a_np = rng.standard_normal((m, k)).astype(np.float32)
136+
b_np = rng.standard_normal((k, n)).astype(np.float32)
137+
lhs_mask_np, rhs_mask_np, out_mask_np = make_masks(m, n, k, bs, sparsity, rng)
138+
139+
a_mx = mx.array(a_np, dtype=mlx_dtype)
140+
b_mx = mx.array(b_np, dtype=mlx_dtype)
141+
lhs_mask_mx = mx.array(lhs_mask_np)
142+
rhs_mask_mx = mx.array(rhs_mask_np)
143+
out_mask_mx = mx.array(out_mask_np)
144+
mx.eval(a_mx, b_mx, lhs_mask_mx, rhs_mask_mx, out_mask_mx)
145+
146+
# Correctness check: block_masked_mm vs naive expand+matmul
147+
err_str = ""
148+
if not args.no_check:
149+
y_op = mx.block_masked_mm(
150+
a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
151+
)
152+
y_naive = mlx_naive_block_masked_mm(
153+
a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
154+
)
155+
mx.eval(y_op, y_naive)
156+
err = float(mx.max(mx.abs(y_op - y_naive)).item())
157+
err_str = f"{err:.2e}"
158+
159+
# Benchmark
160+
t_mlx = bench_mlx(
161+
lambda: mx.block_masked_mm(
162+
a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
163+
),
164+
args.warmup,
165+
args.iters,
166+
)
167+
t_naive = bench_mlx(
168+
lambda: mlx_naive_block_masked_mm(
169+
a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
170+
),
171+
args.warmup,
172+
args.iters,
173+
)
174+
speedup = f"{t_naive / t_mlx:.2f}x" if t_mlx > 0 else "-"
175+
176+
row = [
177+
f"{m}x{n}x{k}x{bs}",
178+
f"{sparsity:.0%}",
179+
f"{t_mlx:.3f}",
180+
f"{t_naive:.3f}",
181+
speedup,
182+
]
183+
if not args.no_check:
184+
row.append(err_str)
185+
rows.append(row)
186+
187+
print_table(headers, rows)
188+
if not args.no_check:
189+
print("err: max|block_masked_mm - naive_expand_matmul|")
190+
191+
192+
if __name__ == "__main__":
193+
main()

0 commit comments

Comments
 (0)