Skip to content

Commit c717a03

Browse files
committed
misc: memoized_instances -> weak_instance_cache
1 parent 6175676 commit c717a03

2 files changed

Lines changed: 36 additions & 76 deletions

File tree

devito/ir/support/basic.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from itertools import chain, product
22
from functools import cached_property
3-
from typing import Any
3+
from typing import Any, TypeVar
44

55
from sympy import S
66
import sympy
@@ -13,8 +13,7 @@
1313
uxreplace)
1414
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1515
flatten, memoized_meth, memoized_generator, smart_gt,
16-
smart_lt)
17-
from devito.tools.memoization import _memoized_instances
16+
smart_lt, weak_instance_cache)
1817
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1918
CriticalRegion, Function, Symbol, Temp, TempArray,
2019
TBArray)
@@ -511,13 +510,26 @@ def touched_halo(self, findex):
511510
return (touch_halo_left, touch_halo_right)
512511

513512

514-
@_memoized_instances
513+
# Type variable for subclass of `Relation` returned from `maybe_cached`
514+
RelationType = TypeVar('RelationType', bound='Relation')
515+
516+
515517
class Relation:
516518

517519
"""
518520
A relation between two TimedAccess objects.
519521
"""
520522

523+
@classmethod
524+
@weak_instance_cache
525+
def maybe_cached(cls: type[RelationType],
526+
source: TimedAccess, sink: TimedAccess) -> RelationType:
527+
"""
528+
Constructs a new Relation or retrieves one from the cache if it
529+
already/still exists.
530+
"""
531+
return cls(source, sink)
532+
521533
def __init__(self, source, sink):
522534
assert isinstance(source, TimedAccess) and isinstance(sink, TimedAccess)
523535
assert source.function == sink.function
@@ -628,7 +640,6 @@ def is_imaginary(self):
628640
return S.ImaginaryUnit in self.distance
629641

630642

631-
@_memoized_instances
632643
class Dependence(Relation):
633644

634645
"""
@@ -1082,7 +1093,7 @@ def d_flow_gen(self):
10821093
if any(not rule(w, r) for rule in self.rules):
10831094
continue
10841095

1085-
dependence = Dependence(w, r)
1096+
dependence = Dependence.maybe_cached(w, r)
10861097

10871098
if dependence.is_imaginary:
10881099
continue
@@ -1112,7 +1123,7 @@ def d_anti_gen(self):
11121123
if any(not rule(r, w) for rule in self.rules):
11131124
continue
11141125

1115-
dependence = Dependence(r, w)
1126+
dependence = Dependence.maybe_cached(r, w)
11161127

11171128
if dependence.is_imaginary:
11181129
continue
@@ -1142,7 +1153,7 @@ def d_output_gen(self):
11421153
if any(not rule(w2, w1) for rule in self.rules):
11431154
continue
11441155

1145-
dependence = Dependence(w2, w1)
1156+
dependence = Dependence.maybe_cached(w2, w1)
11461157

11471158
if dependence.is_imaginary:
11481159
continue
@@ -1204,7 +1215,7 @@ def r_gen(self):
12041215
if a0 is a1:
12051216
continue
12061217

1207-
r = Relation(a0, a1)
1218+
r = Relation.maybe_cached(a0, a1)
12081219
if r.is_imaginary:
12091220
continue
12101221

devito/tools/memoization.py

Lines changed: 16 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito.tools import WeakValueCache
77

88

9-
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator']
9+
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator', 'weak_instance_cache']
1010

1111

1212
class memoized_func:
@@ -131,79 +131,28 @@ def __call__(self, *args, **kwargs):
131131
return result
132132

133133

134-
# Describes the type of an object cached by `memoized_constructor`
134+
# Describes the type of an object cached by `weak_instance_cache`
135135
InstanceType = TypeVar('InstanceType')
136-
Constructor = Callable[..., InstanceType]
136+
Constructor = TypeVar('Constructor', bound=Callable[..., InstanceType])
137137

138138

139-
def _memoized_instances(cls: type[InstanceType]) -> type[InstanceType]:
139+
def weak_instance_cache(fun: Constructor[InstanceType]) -> Constructor[InstanceType]:
140140
"""
141-
Decorator for a class that caches instances based on the hash values of
142-
constructing arguments. The constructed values are stored weakly and
143-
evicted from the cache when no longer referenced.
141+
Decorator for a class method that caches instances based on the hash
142+
values of constructing arguments. The constructed values are stored
143+
weakly and evicted from the cache when no longer referenced.
144144
145-
We need to override both __new__ and __init__ to ensure initialization
146-
only happens once for a cached instance.
145+
See `Dependence` for a usage example.
147146
"""
148147

149-
# Check if we already decorated a parent class
150-
already_applied = getattr(cls, '_memoized_instances__exists', False)
151-
cls._memoized_instances__exists = True
152-
153-
new = cls.__new__
154-
init = cls.__init__
155-
cache: WeakValueCache[InstanceType] = WeakValueCache(cls)
156-
157-
@wraps(new)
158-
def _new(_cls: type[InstanceType], *args: Hashable,
159-
_memoized_instances__use_cache: bool = True,
160-
**kwargs: Hashable) -> InstanceType:
161-
# The decorator must be reapplied to a child class, so we make sure cls matches
162-
if _memoized_instances__use_cache and _cls is not cls:
163-
raise TypeError(f"_memoized_instances must be applied to {_cls.__name__}, "
164-
f"not (just) {cls.__name__}")
165-
166-
# The cache called the constructor; avoid infinite recursion
167-
if not _memoized_instances__use_cache:
168-
# If the class doesn't define __new__, we can't pass any args
169-
if new is object.__new__:
170-
obj = new(_cls)
171-
172-
# Otherwise forward all arguments
173-
else:
174-
# If we applied the decorator to a parent class, forward the caching flag
175-
if already_applied:
176-
kwargs['_memoized_instances__use_cache'] = False
177-
obj = new(_cls, *args, **kwargs)
178-
179-
# Set our initialization flag and return
180-
obj._memoized_instances__initialized = False
181-
return obj
182-
183-
return cache.get_or_create(*args, _memoized_instances__use_cache=False, **kwargs)
184-
185-
@wraps(init)
186-
def _init(self: InstanceType, *args: Hashable, **kwargs: Hashable) -> None:
187-
# Skip reinitialization if this object was obtained from the cache
188-
try:
189-
if self._memoized_instances__initialized:
190-
return
191-
except AttributeError:
192-
# If the attribute doesn't exist, this is a new instance
193-
self._memoized_instances__initialized = False
194-
195-
# Don't forward our extra argument to the original __init__
196-
kwargs.pop('_memoized_instances__use_cache', None)
197-
init(self, *args, **kwargs)
198-
self._memoized_instances__initialized = True
148+
# Initialize a weak value cache bound to this constructor
149+
cache = WeakValueCache(fun)
199150

200-
def _copy(self: InstanceType) -> InstanceType:
201-
# Copy should just return the cached instance; bypass the cache machinery
202-
return self
151+
@wraps(fun)
152+
def _wrapper(cls: type[InstanceType], *args: Hashable, **kwargs: Hashable) \
153+
-> InstanceType:
154+
# Retrieve from the cache or create a new instance
155+
return cache.get_or_create(cls, *args, **kwargs)
203156

204-
# Update the class's methods
205-
cls.__new__ = _new
206-
cls.__init__ = _init
207-
cls.__copy__ = _copy
157+
return _wrapper
208158

209-
return cls

0 commit comments

Comments
 (0)