Skip to content

Commit 074ade3

Browse files
committed
Add free threaded build
1 parent 1091e3d commit 074ade3

5 files changed

Lines changed: 48 additions & 36 deletions

File tree

.github/actions/test-linux/action.yml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,27 @@ inputs:
99
runs:
1010
using: "composite"
1111
steps:
12+
# FIXME: The distributed tests fail with free-threading Python; detect it here so later steps can be skipped.
13+
- name: Check free-threading Python
14+
id: is-free-threading
15+
shell: bash
16+
run: |
17+
if python -VV 2>&1 | grep -q "free-threading"; then
18+
echo "result=true" >> "$GITHUB_OUTPUT"
19+
else
20+
echo "result=false" >> "$GITHUB_OUTPUT"
21+
fi
22+
1223
- name: Run MPI tests
24+
if: ${{ steps.is-free-threading.outputs.result == 'false' }}  # skipped when the is-free-threading step reports a free-threading Python
1325
shell: bash
1426
run: |
1527
echo "::group::MPI tests"
1628
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
1729
echo "::endgroup::"
1830
1931
- name: Run distributed tests
20-
if: ${{ inputs.has-gpu == 'false' }}
32+
if: ${{ steps.is-free-threading.outputs.result == 'false' }}
2133
shell: bash
2234
run: |
2335
echo "::group::Distributed tests"

.github/workflows/nightly.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
strategy:
4242
fail-fast: false
4343
matrix:
44-
python_version: ["3.11", "3.12", "3.13", "3.14"]
44+
python_version: ["3.11", "3.12", "3.13", "3.14", "3.14t"]
4545
runner:
4646
- ubuntu-22.04
4747
- ubuntu-22.04-arm
@@ -59,7 +59,7 @@ jobs:
5959
if: github.repository == 'ml-explore/mlx'
6060
strategy:
6161
matrix:
62-
python-version: ["3.10", "3.13"]
62+
python-version: ["3.10", "3.13", "3.14t"]
6363
runs-on: [self-hosted, macos]
6464
steps:
6565
- uses: actions/checkout@v6

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
if: github.repository == 'ml-explore/mlx'
4848
strategy:
4949
matrix:
50-
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
50+
python_version: ["3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]
5151
arch: ['x86_64', 'aarch64']
5252
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
5353
env:
@@ -83,7 +83,7 @@ jobs:
8383
if: github.repository == 'ml-explore/mlx'
8484
strategy:
8585
matrix:
86-
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
86+
python-version: ["3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]
8787
runs-on: [self-hosted, macos]
8888
env:
8989
PYPI_RELEASE: 1

python/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ nanobind_add_module(
22
core
33
NB_STATIC
44
STABLE_ABI
5+
FREE_THREADED
56
LTO
67
NOMINSIZE
78
NB_DOMAIN

python/tests/mlx_distributed_tests.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -135,37 +135,36 @@ def test_shard_linear(self):
135135
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
136136
self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))
137137

138-
# And their quant versions (QuantizedMatmul is not supported on CUDA)
139-
if not mx.cuda.is_available():
140-
qlin = lin.to_quantized()
141-
slin1 = shard_linear(qlin, "all-to-sharded")
142-
slin2 = shard_linear(qlin, "sharded-to-all")
143-
y = qlin(x)
144-
y1 = slin1(x)
145-
y2 = slin2(x[part])
146-
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
147-
self.assertTrue(mx.allclose(y[part], y1))
148-
149-
# Test non-affine quantization modes (mxfp8)
150-
qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")
151-
self.assertEqual(qlin_mxfp8.mode, "mxfp8")
152-
153-
slin1_mxfp8 = shard_linear(qlin_mxfp8, "all-to-sharded")
154-
slin2_mxfp8 = shard_linear(qlin_mxfp8, "sharded-to-all")
155-
156-
# Verify mode is propagated
157-
self.assertEqual(slin1_mxfp8.mode, "mxfp8")
158-
self.assertEqual(slin2_mxfp8.mode, "mxfp8")
159-
160-
# Verify biases parameter is not set for mxfp8
161-
self.assertIsNone(slin1_mxfp8.get("biases"))
162-
self.assertIsNone(slin2_mxfp8.get("biases"))
163-
164-
y = qlin_mxfp8(x)
165-
y1 = slin1_mxfp8(x)
166-
y2 = slin2_mxfp8(x[part])
167-
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
168-
self.assertTrue(mx.allclose(y[part], y1))
138+
# And their quant versions
139+
qlin = lin.to_quantized()
140+
slin1 = shard_linear(qlin, "all-to-sharded")
141+
slin2 = shard_linear(qlin, "sharded-to-all")
142+
y = qlin(x)
143+
y1 = slin1(x)
144+
y2 = slin2(x[part])
145+
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
146+
self.assertTrue(mx.allclose(y[part], y1))
147+
148+
# Test non-affine quantization modes (mxfp8)
149+
qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode="mxfp8")
150+
self.assertEqual(qlin_mxfp8.mode, "mxfp8")
151+
152+
slin1_mxfp8 = shard_linear(qlin_mxfp8, "all-to-sharded")
153+
slin2_mxfp8 = shard_linear(qlin_mxfp8, "sharded-to-all")
154+
155+
# Verify mode is propagated
156+
self.assertEqual(slin1_mxfp8.mode, "mxfp8")
157+
self.assertEqual(slin2_mxfp8.mode, "mxfp8")
158+
159+
# Verify biases parameter is not set for mxfp8
160+
self.assertIsNone(slin1_mxfp8.get("biases"))
161+
self.assertIsNone(slin2_mxfp8.get("biases"))
162+
163+
y = qlin_mxfp8(x)
164+
y1 = slin1_mxfp8(x)
165+
y2 = slin2_mxfp8(x[part])
166+
self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
167+
self.assertTrue(mx.allclose(y[part], y1))
169168

170169
# Check the backward works as expected
171170
def dummy_loss(model, x, y):

0 commit comments

Comments
 (0)