File tree Expand file tree Collapse file tree 2 files changed +34
-4
lines changed
Expand file tree Collapse file tree 2 files changed +34
-4
lines changed Original file line number Diff line number Diff line change @@ -439,12 +439,27 @@ cdef class _DefaultDeviceCache:
439439 return _copy
440440
441441
442+ # no default, as would share a single mutable instance across threads and
443+ # concurrent access to the cache would not be thread-safe. Using ContextVar
444+ # without a default ensures each context gets its own instance.
442445_global_default_device_cache = ContextVar(
443446 " global_default_device_cache" ,
444- default = _DefaultDeviceCache()
445447)
446448
447449
450+ cdef _DefaultDeviceCache _get_default_device_cache():
451+ """
452+ Factory function to get or create a default device cache for the current
453+ context
454+ """
455+ try :
456+ return _global_default_device_cache.get()
457+ except LookupError :
458+ cache = _DefaultDeviceCache()
459+ _global_default_device_cache.set(cache)
460+ return cache
461+
462+
448463cpdef SyclDevice _cached_default_device():
449464 """ Returns a cached device selected by default selector.
450465
@@ -453,7 +468,7 @@ cpdef SyclDevice _cached_default_device():
453468 A cached default-selected SYCL device.
454469
455470 """
456- cdef _DefaultDeviceCache _cache = _global_default_device_cache.get ()
471+ cdef _DefaultDeviceCache _cache = _get_default_device_cache ()
457472 d_, changed_ = _cache.get_or_create()
458473 if changed_:
459474 _global_default_device_cache.set(_cache)
Original file line number Diff line number Diff line change @@ -75,12 +75,27 @@ cdef class _DeviceDefaultQueueCache:
7575 return _copy
7676
7777
78+ # no default, as would share a single mutable instance across threads and
79+ # concurrent access to the cache would not be thread-safe. Using ContextVar
80+ # without a default ensures each context gets its own instance.
7881_global_device_queue_cache = ContextVar(
7982 " global_device_queue_cache" ,
80- default = _DeviceDefaultQueueCache()
8183)
8284
8385
86+ cdef _DeviceDefaultQueueCache _get_device_queue_cache():
87+ """
88+ Factory function to get or create a default device queue cache for the
89+ current context
90+ """
91+ try :
92+ return _global_device_queue_cache.get()
93+ except LookupError :
94+ cache = _DeviceDefaultQueueCache()
95+ _global_device_queue_cache.set(cache)
96+ return cache
97+
98+
8499cpdef object get_device_cached_queue(object key):
85100 """ Returns a cached queue associated with given device.
86101
@@ -97,7 +112,7 @@ cpdef object get_device_cached_queue(object key):
97112 TypeError: If the input key is not one of the accepted types.
98113
99114 """
100- _cache = _global_device_queue_cache.get ()
115+ _cache = _get_device_queue_cache ()
101116 q_, changed_ = _cache.get_or_create(key)
102117 if changed_:
103118 _global_device_queue_cache.set(_cache)
You can’t perform that action at this time.
0 commit comments