Skip to content

Commit 7b2269b

Browse files
committed
update caching for free-threaded python compatibility
1 parent 29c7379 commit 7b2269b

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

dpctl/_sycl_device_factory.pyx

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,27 @@ cdef class _DefaultDeviceCache:
439439
return _copy
440440

441441

442+
# no default, as would share a single mutable instance across threads and
443+
# concurrent access to the cache would not be thread-safe. Using ContextVar
444+
# without a default ensures each context gets its own instance.
442445
_global_default_device_cache = ContextVar(
443446
"global_default_device_cache",
444-
default=_DefaultDeviceCache()
445447
)
446448

447449

450+
cdef _DefaultDeviceCache _get_default_device_cache():
451+
"""
452+
Factory function to get or create a default device cache for the current
453+
context
454+
"""
455+
try:
456+
return _global_default_device_cache.get()
457+
except LookupError:
458+
cache = _DefaultDeviceCache()
459+
_global_default_device_cache.set(cache)
460+
return cache
461+
462+
448463
cpdef SyclDevice _cached_default_device():
449464
"""Returns a cached device selected by default selector.
450465
@@ -453,7 +468,7 @@ cpdef SyclDevice _cached_default_device():
453468
A cached default-selected SYCL device.
454469
455470
"""
456-
cdef _DefaultDeviceCache _cache = _global_default_device_cache.get()
471+
cdef _DefaultDeviceCache _cache = _get_default_device_cache()
457472
d_, changed_ = _cache.get_or_create()
458473
if changed_:
459474
_global_default_device_cache.set(_cache)

dpctl/_sycl_queue_manager.pyx

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,27 @@ cdef class _DeviceDefaultQueueCache:
7575
return _copy
7676

7777

78+
# no default, as would share a single mutable instance across threads and
79+
# concurrent access to the cache would not be thread-safe. Using ContextVar
80+
# without a default ensures each context gets its own instance.
7881
_global_device_queue_cache = ContextVar(
7982
"global_device_queue_cache",
80-
default=_DeviceDefaultQueueCache()
8183
)
8284

8385

86+
cdef _DeviceDefaultQueueCache _get_device_queue_cache():
87+
"""
88+
Factory function to get or create a default device queue cache for the
89+
current context
90+
"""
91+
try:
92+
return _global_device_queue_cache.get()
93+
except LookupError:
94+
cache = _DeviceDefaultQueueCache()
95+
_global_device_queue_cache.set(cache)
96+
return cache
97+
98+
8499
cpdef object get_device_cached_queue(object key):
85100
"""Returns a cached queue associated with given device.
86101
@@ -97,7 +112,7 @@ cpdef object get_device_cached_queue(object key):
97112
TypeError: If the input key is not one of the accepted types.
98113
99114
"""
100-
_cache = _global_device_queue_cache.get()
115+
_cache = _get_device_queue_cache()
101116
q_, changed_ = _cache.get_or_create(key)
102117
if changed_:
103118
_global_device_queue_cache.set(_cache)

0 commit comments

Comments
 (0)