@@ -41,6 +41,26 @@ class PyStreamContext {
4141 mx::StreamContext* _inner;
4242};
4343
44+ class PyThreadLocalStream {
45+ public:
46+ PyThreadLocalStream (mx::Device d) : device(d) {}
47+
48+ mx::Stream stream () const {
49+ thread_local std::unordered_map<const PyThreadLocalStream*, mx::Stream>
50+ streams;
51+
52+ auto it = streams.find (this );
53+ if (it == streams.end ()) {
54+ auto result = streams.emplace (this , mx::new_stream (device));
55+ it = result.first ;
56+ }
57+
58+ return it->second ;
59+ }
60+
61+ mx::Device device;
62+ };
63+
4464void init_stream (nb::module_& m) {
4565 nb::class_<mx::Stream>(
4666 m,
@@ -49,6 +69,11 @@ void init_stream(nb::module_& m) {
4969 A stream for running operations on a given device.
5070 )pbdoc" )
5171 .def_ro (" device" , &mx::Stream::device)
72+ .def (
73+ " __init__" ,
74+ [](mx::Stream* s, const PyThreadLocalStream& tls) {
75+ return new (s) mx::Stream (tls.stream ());
76+ })
5277 .def (
5378 " __repr__" ,
5479 [](const mx::Stream& s) {
@@ -61,7 +86,29 @@ void init_stream(nb::module_& m) {
6186 s == nb::cast<mx::Stream>(other);
6287 });
6388
89+ nb::class_<PyThreadLocalStream>(
90+ m,
91+ " ThreadLocalStream" ,
92+ R"pbdoc(
93+ A stream that will be unique per thread and can be used to run operations on a given device.
94+ )pbdoc" )
95+ .def_ro (" device" , &PyThreadLocalStream::device)
96+ .def (nb::init<mx::Device>())
97+ .def (
98+ " __repr__" ,
99+ [](const PyThreadLocalStream& s) {
100+ std::ostringstream os;
101+ os << " ThreadLocalStream(" << s.device << " )" ;
102+ return os.str ();
103+ })
104+ .def (" __eq__" , [](const PyThreadLocalStream& s, const nb::object& other) {
105+ auto s_other = mx::default_stream (mx::default_device ());
106+ return nb::try_cast<mx::Stream>(other, s_other) &&
107+ s_other == s.stream ();
108+ });
109+
64110 nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
111+ nb::implicitly_convertible<PyThreadLocalStream, mx::Stream>();
65112
66113 m.def (
67114 " default_stream" ,
0 commit comments