Skip to content

Commit 6a9a121

Browse files
authored
Add a convenience for making local streams in python (#3355)
1 parent befe42d commit 6a9a121

7 files changed

Lines changed: 139 additions & 9 deletions

File tree

mlx/compile.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,10 @@ class CompilerCache {
374374
cache_.clear();
375375
}
376376

377+
bool empty() {
378+
return cache_.empty();
379+
}
380+
377381
private:
378382
CompilerCache() {
379383
// Make sure the allocator is fully
@@ -1192,6 +1196,10 @@ void compile_clear_cache() {
11921196
detail::compiler_cache().clear();
11931197
}
11941198

1199+
bool compile_cache_empty() {
1200+
return detail::compiler_cache().empty();
1201+
}
1202+
11951203
} // namespace detail
11961204

11971205
std::function<std::vector<array>(const std::vector<array>&)> compile(

mlx/compile_impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ MLX_API void compile_erase(std::uintptr_t fun_id);
3434
// when called again.
3535
MLX_API void compile_clear_cache();
3636

37+
// Return true if the cache is empty.
38+
MLX_API bool compile_cache_empty();
39+
3740
bool compile_available_for_device(const Device& device);
3841

3942
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>

python/mlx/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright © 2023 Apple Inc.
2+
23
from collections import defaultdict
34
from itertools import zip_longest
45
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

python/src/random.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ using namespace nb::literals;
1818

1919
class PyKeySequence {
2020
public:
21-
PyKeySequence() {
22-
// Destroy state before the python interpreter exits.
23-
auto atexit = nb::module_::import_("atexit");
24-
atexit.attr("register")(nb::cpp_function([this]() { state_.reset(); }));
21+
~PyKeySequence() {
22+
if (state_.has_value()) {
23+
nb::gil_scoped_acquire gil;
24+
state_.reset();
25+
}
26+
}
27+
28+
void reset() {
29+
state_.reset();
2530
}
2631

2732
void seed(uint64_t seed) {
@@ -521,4 +526,10 @@ void init_random(nb::module_& parent_module) {
521526
array:
522527
The generated random permutation or randomly permuted input array.
523528
)pbdoc");
529+
530+
// Ensure the main thread cleanup will happen before the interpreter goes
531+
// away. As a result if the other threads join the main thread we should have
532+
// a clean tear-down.
533+
auto atexit = nb::module_::import_("atexit");
534+
atexit.attr("register")(nb::cpp_function([]() { default_key().reset(); }));
524535
}

python/src/stream.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ class PyStreamContext {
4141
mx::StreamContext* _inner;
4242
};
4343

44+
class PyThreadLocalStream {
45+
public:
46+
PyThreadLocalStream(mx::Device d) : device(d) {}
47+
48+
mx::Stream stream() const {
49+
thread_local std::unordered_map<const PyThreadLocalStream*, mx::Stream>
50+
streams;
51+
52+
auto it = streams.find(this);
53+
if (it == streams.end()) {
54+
auto result = streams.emplace(this, mx::new_stream(device));
55+
it = result.first;
56+
}
57+
58+
return it->second;
59+
}
60+
61+
mx::Device device;
62+
};
63+
4464
void init_stream(nb::module_& m) {
4565
nb::class_<mx::Stream>(
4666
m,
@@ -49,6 +69,11 @@ void init_stream(nb::module_& m) {
4969
A stream for running operations on a given device.
5070
)pbdoc")
5171
.def_ro("device", &mx::Stream::device)
72+
.def(
73+
"__init__",
74+
[](mx::Stream* s, const PyThreadLocalStream& tls) {
75+
return new (s) mx::Stream(tls.stream());
76+
})
5277
.def(
5378
"__repr__",
5479
[](const mx::Stream& s) {
@@ -61,7 +86,29 @@ void init_stream(nb::module_& m) {
6186
s == nb::cast<mx::Stream>(other);
6287
});
6388

89+
nb::class_<PyThreadLocalStream>(
90+
m,
91+
"ThreadLocalStream",
92+
R"pbdoc(
93+
A stream that will be unique per thread and can be used to run operations on a given device.
94+
)pbdoc")
95+
.def_ro("device", &PyThreadLocalStream::device)
96+
.def(nb::init<mx::Device>())
97+
.def(
98+
"__repr__",
99+
[](const PyThreadLocalStream& s) {
100+
std::ostringstream os;
101+
os << "ThreadLocalStream(" << s.device << ")";
102+
return os.str();
103+
})
104+
.def("__eq__", [](const PyThreadLocalStream& s, const nb::object& other) {
105+
auto s_other = mx::default_stream(mx::default_device());
106+
return nb::try_cast<mx::Stream>(other, s_other) &&
107+
s_other == s.stream();
108+
});
109+
64110
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
111+
nb::implicitly_convertible<PyThreadLocalStream, mx::Stream>();
65112

66113
m.def(
67114
"default_stream",

python/src/transforms.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,12 +1463,22 @@ void init_transforms(nb::module_& m) {
14631463
bool shapeless) {
14641464
// Make sure each thread using mx.compile would clear its compile cache
14651465
// before python interpreter exits.
1466-
static thread_local auto clear_cache = []() {
1467-
auto atexit = nb::module_::import_("atexit");
1468-
atexit.attr("register")(
1469-
nb::cpp_function(&mx::detail::compile_clear_cache));
1470-
return true;
1466+
struct ThreadCleanup {
1467+
~ThreadCleanup() {
1468+
if (!mx::detail::compile_cache_empty()) {
1469+
nb::gil_scoped_acquire gil;
1470+
mx::detail::compile_clear_cache();
1471+
}
1472+
}
14711473
};
1474+
static thread_local auto clear_cache = []() {
1475+
// Ensure it is created
1476+
mx::detail::compile_clear_cache();
1477+
1478+
// Ensure it will be cleaned up
1479+
return ThreadCleanup{};
1480+
}();
1481+
14721482
return mlx_func(
14731483
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
14741484
fun,
@@ -1542,4 +1552,10 @@ void init_transforms(nb::module_& m) {
15421552
A callable that recomputes intermediate states during gradient
15431553
computation.
15441554
)pbdoc");
1555+
1556+
// Ensure the main thread cleanup will happen before the interpreter goes
1557+
// away. As a result if the other threads join the main thread we should have
1558+
// a clean tear-down.
1559+
auto atexit = nb::module_::import_("atexit");
1560+
atexit.attr("register")(nb::cpp_function(&mx::detail::compile_clear_cache));
15451561
}

python/tests/test_threads.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright © 2026 Apple Inc.
2+
3+
import threading
4+
import unittest
5+
6+
import mlx.core as mx
7+
import mlx_tests
8+
9+
10+
class TestReduce(mlx_tests.MLXTestCase):
11+
def test_threadlocal_stream(self):
12+
test_stream = mx.new_stream(mx.default_device())
13+
14+
def test_failure():
15+
with self.assertRaises(RuntimeError):
16+
with mx.stream(test_stream):
17+
x = mx.arange(10)
18+
mx.eval(2 * x)
19+
20+
t1 = threading.Thread(target=test_failure)
21+
t2 = threading.Thread(target=test_failure)
22+
t1.start()
23+
t2.start()
24+
t1.join()
25+
t2.join()
26+
27+
test_stream = mx.ThreadLocalStream(mx.default_device())
28+
29+
def test_success():
30+
with mx.stream(test_stream):
31+
x = mx.arange(10)
32+
mx.eval(2 * x)
33+
self.assertEqual(x.tolist(), list(range(10)))
34+
35+
t1 = threading.Thread(target=test_success)
36+
t2 = threading.Thread(target=test_success)
37+
t1.start()
38+
t2.start()
39+
t1.join()
40+
t2.join()
41+
42+
43+
if __name__ == "__main__":
44+
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)