Skip to content

Commit 65df8d3

Browse files
committed
Make ordermanager free-threading safe
1 parent f77ec0a commit 65df8d3

File tree

3 files changed

+75
-45
lines changed

3 files changed

+75
-45
lines changed

dpctl/_sycl_queue.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ cdef public api class _SyclQueue [
4343
cdef DPCTLSyclQueueRef _queue_ref
4444
cdef SyclContext _context
4545
cdef SyclDevice _device
46+
cdef object __weakref__
4647

4748

4849
cdef public api class SyclQueue (_SyclQueue) [

dpctl/utils/_order_manager.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class _SequentialOrderManager:
1616
def __init__(self):
1717
self._state = _OrderManager(16)
1818

19-
def __dealloc__(self):
19+
def __del__(self):
2020
_local = self._state
2121
SyclEvent.wait_for(_local.get_submitted_events())
2222
SyclEvent.wait_for(_local.get_host_task_events())
@@ -71,24 +71,31 @@ class SyclQueueToOrderManagerMap:
7171
def __init__(self):
7272
self._map = ContextVar(
7373
"global_order_manager_map",
74-
default=defaultdict(_SequentialOrderManager),
74+
# no default to avoid sharing a single defaultdict
75+
# across threads
7576
)
7677

78+
def _get_map(self):
79+
"""
80+
Factory method to get or create a default device queue cache for the
81+
current context
82+
"""
83+
try:
84+
return self._map.get()
85+
except LookupError:
86+
m = defaultdict(_SequentialOrderManager)
87+
self._map.set(m)
88+
return m
89+
7790
def __getitem__(self, q: SyclQueue) -> _SequentialOrderManager:
7891
"""Get order manager for given SyclQueue"""
79-
_local = self._map.get()
8092
if not isinstance(q, SyclQueue):
8193
raise TypeError(f"Expected `dpctl.SyclQueue`, got {type(q)}")
82-
if q in _local:
83-
return _local[q]
84-
else:
85-
v = _local[q]
86-
_local[q] = v
87-
return v
94+
return self._get_map()[q]
8895

8996
def clear(self):
9097
"""Clear content of internal dictionary"""
91-
_local = self._map.get()
98+
_local = self._get_map()
9299
for v in _local.values():
93100
v.wait()
94101
_local.clear()

dpctl/utils/src/sequential_order_keeper.hpp

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <algorithm>
55
#include <cstddef>
6+
#include <mutex>
67
#include <vector>
78

89
namespace
@@ -21,10 +22,12 @@ inline bool is_event_complete(const sycl::event &e)
2122
class SequentialOrder
2223
{
2324
private:
25+
mutable std::mutex mu_events;
2426
std::vector<sycl::event> host_task_events;
2527
std::vector<sycl::event> submitted_events;
2628

27-
void prune_complete()
29+
// only called with mu_events held
30+
void prune_complete_nolock()
2831
{
2932
const auto &ht_it =
3033
std::remove_if(host_task_events.begin(), host_task_events.end(),
@@ -46,76 +49,78 @@ class SequentialOrder
4649
}
4750

4851
SequentialOrder(const SequentialOrder &other)
49-
: host_task_events(other.host_task_events),
50-
submitted_events(other.submitted_events)
5152
{
52-
prune_complete();
53+
std::lock_guard<std::mutex> lock(other.mu_events);
54+
host_task_events = other.host_task_events;
55+
submitted_events = other.submitted_events;
56+
prune_complete_nolock();
5357
}
5458
SequentialOrder(SequentialOrder &&other)
5559
: host_task_events{}, submitted_events{}
5660
{
61+
std::lock_guard<std::mutex> lock(other.mu_events);
5762
host_task_events = std::move(other.host_task_events);
5863
submitted_events = std::move(other.submitted_events);
59-
prune_complete();
64+
prune_complete_nolock();
6065
}
6166

6267
SequentialOrder &operator=(const SequentialOrder &other)
6368
{
64-
host_task_events = other.host_task_events;
65-
submitted_events = other.submitted_events;
66-
67-
prune_complete();
69+
if (this != &other) {
70+
std::scoped_lock lock(mu_events, other.mu_events);
71+
host_task_events = other.host_task_events;
72+
submitted_events = other.submitted_events;
73+
prune_complete_nolock();
74+
}
6875
return *this;
6976
}
7077

7178
SequentialOrder &operator=(SequentialOrder &&other)
7279
{
7380
if (this != &other) {
81+
std::scoped_lock lock(mu_events, other.mu_events);
7482
host_task_events = std::move(other.host_task_events);
7583
submitted_events = std::move(other.submitted_events);
76-
prune_complete();
84+
prune_complete_nolock();
7785
}
7886
return *this;
7987
}
8088

8189
std::size_t get_num_submitted_events() const
8290
{
91+
std::lock_guard<std::mutex> lock(mu_events);
8392
return submitted_events.size();
8493
}
8594

86-
const std::vector<sycl::event> &get_host_task_events()
95+
// returns a copy to avoid returning a reference that
96+
// could be modified after the lock is released
97+
std::vector<sycl::event> get_host_task_events()
8798
{
88-
prune_complete();
99+
std::lock_guard<std::mutex> lock(mu_events);
100+
prune_complete_nolock();
89101
return host_task_events;
90102
}
91103

92-
/*
93-
const std::vector<sycl::event> & get_host_task_events() const {
94-
return host_task_events;
95-
}
96-
*/
97-
98104
std::size_t get_num_host_task_events() const
99105
{
106+
std::lock_guard<std::mutex> lock(mu_events);
100107
return host_task_events.size();
101108
}
102109

103-
const std::vector<sycl::event> &get_submitted_events()
110+
// returns a copy to avoid returning a reference that
111+
// could be modified after the lock is released
112+
std::vector<sycl::event> get_submitted_events()
104113
{
105-
prune_complete();
114+
std::lock_guard<std::mutex> lock(mu_events);
115+
prune_complete_nolock();
106116
return submitted_events;
107117
}
108118

109-
/*
110-
const std::vector<sycl::event> & get_submitted_events() const {
111-
return submitted_events;
112-
}
113-
*/
114-
115119
void add_to_both_events(const sycl::event &ht_ev,
116120
const sycl::event &comp_ev)
117121
{
118-
prune_complete();
122+
std::lock_guard<std::mutex> lock(mu_events);
123+
prune_complete_nolock();
119124
if (!is_event_complete(ht_ev))
120125
host_task_events.push_back(ht_ev);
121126
if (!is_event_complete(comp_ev))
@@ -125,7 +130,8 @@ class SequentialOrder
125130
void add_vector_to_both_events(const std::vector<sycl::event> &ht_evs,
126131
const std::vector<sycl::event> &comp_evs)
127132
{
128-
prune_complete();
133+
std::lock_guard<std::mutex> lock(mu_events);
134+
prune_complete_nolock();
129135
for (const auto &e : ht_evs) {
130136
if (!is_event_complete(e))
131137
host_task_events.push_back(e);
@@ -138,15 +144,17 @@ class SequentialOrder
138144

139145
void add_to_host_task_events(const sycl::event &ht_ev)
140146
{
141-
prune_complete();
147+
std::lock_guard<std::mutex> lock(mu_events);
148+
prune_complete_nolock();
142149
if (!is_event_complete(ht_ev)) {
143150
host_task_events.push_back(ht_ev);
144151
}
145152
}
146153

147154
void add_to_submitted_events(const sycl::event &comp_ev)
148155
{
149-
prune_complete();
156+
std::lock_guard<std::mutex> lock(mu_events);
157+
prune_complete_nolock();
150158
if (!is_event_complete(comp_ev)) {
151159
submitted_events.push_back(comp_ev);
152160
}
@@ -155,7 +163,8 @@ class SequentialOrder
155163
template <std::size_t num>
156164
void add_list_to_host_task_events(const sycl::event (&ht_events)[num])
157165
{
158-
prune_complete();
166+
std::lock_guard<std::mutex> lock(mu_events);
167+
prune_complete_nolock();
159168
for (std::size_t i = 0; i < num; ++i) {
160169
const auto &e = ht_events[i];
161170
if (!is_event_complete(e))
@@ -166,7 +175,8 @@ class SequentialOrder
166175
template <std::size_t num>
167176
void add_list_to_submitted_events(const sycl::event (&comp_events)[num])
168177
{
169-
prune_complete();
178+
std::lock_guard<std::mutex> lock(mu_events);
179+
prune_complete_nolock();
170180
for (std::size_t i = 0; i < num; ++i) {
171181
const auto &e = comp_events[i];
172182
if (!is_event_complete(e))
@@ -176,8 +186,20 @@ class SequentialOrder
176186

177187
void wait()
178188
{
179-
sycl::event::wait(submitted_events);
180-
sycl::event::wait(host_task_events);
181-
prune_complete();
189+
// snapeshot events outside of mutex to avoid
190+
// calling wait inside mutex
191+
std::vector<sycl::event> sub_copy;
192+
std::vector<sycl::event> ht_copy;
193+
{
194+
std::lock_guard<std::mutex> lock(mu_events);
195+
sub_copy = submitted_events;
196+
ht_copy = host_task_events;
197+
}
198+
sycl::event::wait(sub_copy);
199+
sycl::event::wait(ht_copy);
200+
{
201+
std::lock_guard<std::mutex> lock(mu_events);
202+
prune_complete_nolock();
203+
}
182204
}
183205
};

0 commit comments

Comments
 (0)