Skip to content

Commit 09c710c

Browse files
committed
misc: memoized_instances -> weak_instance_cache
1 parent 6175676 commit 09c710c

3 files changed

Lines changed: 48 additions & 98 deletions

File tree

devito/ir/support/basic.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from itertools import chain, product
22
from functools import cached_property
3-
from typing import Any
3+
from typing import Any, TypeVar
44

55
from sympy import S
66
import sympy
@@ -13,8 +13,7 @@
1313
uxreplace)
1414
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1515
flatten, memoized_meth, memoized_generator, smart_gt,
16-
smart_lt)
17-
from devito.tools.memoization import _memoized_instances
16+
smart_lt, weak_instance_cache)
1817
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1918
CriticalRegion, Function, Symbol, Temp, TempArray,
2019
TBArray)
@@ -511,13 +510,26 @@ def touched_halo(self, findex):
511510
return (touch_halo_left, touch_halo_right)
512511

513512

514-
@_memoized_instances
513+
# Type variable for subclass of `Relation` returned from `maybe_cached`
514+
RelationType = TypeVar('RelationType', bound='Relation')
515+
516+
515517
class Relation:
516518

517519
"""
518520
A relation between two TimedAccess objects.
519521
"""
520522

523+
@classmethod
524+
@weak_instance_cache
525+
def maybe_cached(cls: type[RelationType],
526+
source: TimedAccess, sink: TimedAccess) -> RelationType:
527+
"""
528+
Constructs a new Relation or retrieves one from the cache if it
529+
already/still exists.
530+
"""
531+
return cls(source, sink)
532+
521533
def __init__(self, source, sink):
522534
assert isinstance(source, TimedAccess) and isinstance(sink, TimedAccess)
523535
assert source.function == sink.function
@@ -628,7 +640,6 @@ def is_imaginary(self):
628640
return S.ImaginaryUnit in self.distance
629641

630642

631-
@_memoized_instances
632643
class Dependence(Relation):
633644

634645
"""
@@ -1082,7 +1093,7 @@ def d_flow_gen(self):
10821093
if any(not rule(w, r) for rule in self.rules):
10831094
continue
10841095

1085-
dependence = Dependence(w, r)
1096+
dependence = Dependence.maybe_cached(w, r)
10861097

10871098
if dependence.is_imaginary:
10881099
continue
@@ -1112,7 +1123,7 @@ def d_anti_gen(self):
11121123
if any(not rule(r, w) for rule in self.rules):
11131124
continue
11141125

1115-
dependence = Dependence(r, w)
1126+
dependence = Dependence.maybe_cached(r, w)
11161127

11171128
if dependence.is_imaginary:
11181129
continue
@@ -1142,7 +1153,7 @@ def d_output_gen(self):
11421153
if any(not rule(w2, w1) for rule in self.rules):
11431154
continue
11441155

1145-
dependence = Dependence(w2, w1)
1156+
dependence = Dependence.maybe_cached(w2, w1)
11461157

11471158
if dependence.is_imaginary:
11481159
continue
@@ -1204,7 +1215,7 @@ def r_gen(self):
12041215
if a0 is a1:
12051216
continue
12061217

1207-
r = Relation(a0, a1)
1218+
r = Relation.maybe_cached(a0, a1)
12081219
if r.is_imaginary:
12091220
continue
12101221

devito/tools/abc.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
from collections.abc import Callable
23
import weakref
34
from concurrent.futures import Future
45
from hashlib import sha1
@@ -289,8 +290,9 @@ def __repr__(self):
289290
__str__ = __repr__
290291

291292

292-
# Cached instance type for `WeakValueCache` (not covariant to avoid weird init logic)
293-
ValueType = TypeVar('ValueType')
293+
# Cached instance type for `WeakValueCache`
294+
ValueType = TypeVar('ValueType', covariant=True)
295+
Constructor = Callable[..., ValueType]
294296

295297

296298
class WeakValueCache(Generic[ValueType]):
@@ -304,34 +306,22 @@ class WeakValueCache(Generic[ValueType]):
304306
concurrent access while still allowing for threaded construction.
305307
"""
306308

307-
def __init__(self, cls: type[ValueType]):
308-
self._cls = cls
309+
def __init__(self, constructor: Constructor[ValueType]):
310+
self._constructor = constructor
309311
self._futures: dict[int, Future[ReferenceType[ValueType]]] = {}
310312
self._lock = RLock()
311313

312-
def _make_key(self, *args: Hashable, **kwargs: Hashable) -> int:
313-
return hash((*args, frozenset(kwargs.items())))
314+
def _make_key(self, cls: type[ValueType], *args: Hashable, **kwargs: Hashable) -> int:
315+
return hash((cls, *args, frozenset(kwargs.items())))
314316

315-
def _create_instance(self, *args: Hashable, **kwargs: Hashable) -> ValueType:
316-
if self._cls is object.__new__:
317-
# If the constructor is object's __new__, we cannot pass any arguments
318-
obj = self._cls()
319-
else:
320-
# Otherwise, forward all construction arguments
321-
obj = self._cls(*args, **kwargs)
322-
323-
# Initialize the object so it's ready for consuming threads
324-
obj.__init__(*args, **kwargs)
325-
326-
return obj
327-
328-
def get_or_create(self, *args: Hashable, **kwargs: Hashable) -> ValueType:
317+
def get_or_create(self, cls: type[ValueType],
318+
*args: Hashable, **kwargs: Hashable) -> ValueType:
329319
"""
330320
Gets an instance for the given construction arguments, creating it on this thread
331321
if it doesn't exist. If another thread is currently constructing it, blocks until
332322
the instance is available.
333323
"""
334-
key = self._make_key(*args, **kwargs)
324+
key = self._make_key(cls, *args, **kwargs)
335325
future: Future[ReferenceType[ValueType]]
336326

337327
while True:
@@ -360,7 +350,7 @@ def get_or_create(self, *args: Hashable, **kwargs: Hashable) -> ValueType:
360350
# If we got here, this is the thread that will create the new instance
361351
# Do this outside the lock to allow for concurrent construction
362352
try:
363-
obj = self._create_instance(*args, **kwargs)
353+
obj = self._constructor(cls, *args, **kwargs)
364354

365355
# Listener for when the weak reference expires
366356
def on_obj_destroyed(k: int = key,

devito/tools/memoization.py

Lines changed: 16 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito.tools import WeakValueCache
77

88

9-
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator']
9+
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator', 'weak_instance_cache']
1010

1111

1212
class memoized_func:
@@ -131,79 +131,28 @@ def __call__(self, *args, **kwargs):
131131
return result
132132

133133

134-
# Describes the type of an object cached by `memoized_constructor`
134+
# Describes the type of an object cached by `weak_instance_cache`
135135
InstanceType = TypeVar('InstanceType')
136-
Constructor = Callable[..., InstanceType]
136+
Constructor = TypeVar('Constructor', bound=Callable[..., InstanceType])
137137

138138

139-
def _memoized_instances(cls: type[InstanceType]) -> type[InstanceType]:
139+
def weak_instance_cache(fun: Constructor[InstanceType]) -> Constructor[InstanceType]:
140140
"""
141-
Decorator for a class that caches instances based on the hash values of
142-
constructing arguments. The constructed values are stored weakly and
143-
evicted from the cache when no longer referenced.
141+
Decorator for a class method that caches instances based on the hash
142+
values of constructing arguments. The constructed values are stored
143+
weakly and evicted from the cache when no longer referenced.
144144
145-
We need to override both __new__ and __init__ to ensure initialization
146-
only happens once for a cached instance.
145+
See `Dependence` for a usage example.
147146
"""
148147

149-
# Check if we already decorated a parent class
150-
already_applied = getattr(cls, '_memoized_instances__exists', False)
151-
cls._memoized_instances__exists = True
152-
153-
new = cls.__new__
154-
init = cls.__init__
155-
cache: WeakValueCache[InstanceType] = WeakValueCache(cls)
156-
157-
@wraps(new)
158-
def _new(_cls: type[InstanceType], *args: Hashable,
159-
_memoized_instances__use_cache: bool = True,
160-
**kwargs: Hashable) -> InstanceType:
161-
# The decorator must be reapplied to a child class, so we make sure cls matches
162-
if _memoized_instances__use_cache and _cls is not cls:
163-
raise TypeError(f"_memoized_instances must be applied to {_cls.__name__}, "
164-
f"not (just) {cls.__name__}")
165-
166-
# The cache called the constructor; avoid infinite recursion
167-
if not _memoized_instances__use_cache:
168-
# If the class doesn't define __new__, we can't pass any args
169-
if new is object.__new__:
170-
obj = new(_cls)
171-
172-
# Otherwise forward all arguments
173-
else:
174-
# If we applied the decorator to a parent class, forward the caching flag
175-
if already_applied:
176-
kwargs['_memoized_instances__use_cache'] = False
177-
obj = new(_cls, *args, **kwargs)
178-
179-
# Set our initialization flag and return
180-
obj._memoized_instances__initialized = False
181-
return obj
182-
183-
return cache.get_or_create(*args, _memoized_instances__use_cache=False, **kwargs)
184-
185-
@wraps(init)
186-
def _init(self: InstanceType, *args: Hashable, **kwargs: Hashable) -> None:
187-
# Skip reinitialization if this object was obtained from the cache
188-
try:
189-
if self._memoized_instances__initialized:
190-
return
191-
except AttributeError:
192-
# If the attribute doesn't exist, this is a new instance
193-
self._memoized_instances__initialized = False
194-
195-
# Don't forward our extra argument to the original __init__
196-
kwargs.pop('_memoized_instances__use_cache', None)
197-
init(self, *args, **kwargs)
198-
self._memoized_instances__initialized = True
148+
# Initialize a weak value cache bound to this constructor
149+
cache = WeakValueCache(fun)
199150

200-
def _copy(self: InstanceType) -> InstanceType:
201-
# Copy should just return the cached instance; bypass the cache machinery
202-
return self
151+
@wraps(fun)
152+
def _wrapper(cls: type[InstanceType], *args: Hashable, **kwargs: Hashable) \
153+
-> InstanceType:
154+
# Retrieve from the cache or create a new instance
155+
return cache.get_or_create(cls, *args, **kwargs)
203156

204-
# Update the class's methods
205-
cls.__new__ = _new
206-
cls.__init__ = _init
207-
cls.__copy__ = _copy
157+
return _wrapper
208158

209-
return cls

0 commit comments

Comments
 (0)