|
6 | 6 | from devito.tools import WeakValueCache |
7 | 7 |
|
8 | 8 |
|
9 | | -__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator'] |
| 9 | +__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator', 'weak_instance_cache'] |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class memoized_func: |
@@ -131,79 +131,28 @@ def __call__(self, *args, **kwargs): |
131 | 131 | return result |
132 | 132 |
|
133 | 133 |
|
134 | | -# Describes the type of an object cached by `memoized_constructor` |
| 134 | +# Describes the type of an object cached by `weak_instance_cache` |
135 | 135 | InstanceType = TypeVar('InstanceType') |
136 | | -Constructor = Callable[..., InstanceType] |
| 136 | +Constructor = TypeVar('Constructor', bound=Callable[..., InstanceType]) |
137 | 137 |
|
138 | 138 |
|
139 | | -def _memoized_instances(cls: type[InstanceType]) -> type[InstanceType]: |
| 139 | +def weak_instance_cache(fun: Constructor[InstanceType]) -> Constructor[InstanceType]: |
140 | 140 | """ |
141 | | - Decorator for a class that caches instances based on the hash values of |
142 | | - constructing arguments. The constructed values are stored weakly and |
143 | | - evicted from the cache when no longer referenced. |
| 141 | + Decorator for a class method that caches instances based on the hash |
| 142 | + values of constructing arguments. The constructed values are stored |
| 143 | + weakly and evicted from the cache when no longer referenced. |
144 | 144 |
|
145 | | - We need to override both __new__ and __init__ to ensure initialization |
146 | | - only happens once for a cached instance. |
| 145 | + See `Dependence` for a usage example. |
147 | 146 | """ |
148 | 147 |
|
149 | | - # Check if we already decorated a parent class |
150 | | - already_applied = getattr(cls, '_memoized_instances__exists', False) |
151 | | - cls._memoized_instances__exists = True |
152 | | - |
153 | | - new = cls.__new__ |
154 | | - init = cls.__init__ |
155 | | - cache: WeakValueCache[InstanceType] = WeakValueCache(cls) |
156 | | - |
157 | | - @wraps(new) |
158 | | - def _new(_cls: type[InstanceType], *args: Hashable, |
159 | | - _memoized_instances__use_cache: bool = True, |
160 | | - **kwargs: Hashable) -> InstanceType: |
161 | | - # The decorator must be reapplied to a child class, so we make sure cls matches |
162 | | - if _memoized_instances__use_cache and _cls is not cls: |
163 | | - raise TypeError(f"_memoized_instances must be applied to {_cls.__name__}, " |
164 | | - f"not (just) {cls.__name__}") |
165 | | - |
166 | | - # The cache called the constructor; avoid infinite recursion |
167 | | - if not _memoized_instances__use_cache: |
168 | | - # If the class doesn't define __new__, we can't pass any args |
169 | | - if new is object.__new__: |
170 | | - obj = new(_cls) |
171 | | - |
172 | | - # Otherwise forward all arguments |
173 | | - else: |
174 | | - # If we applied the decorator to a parent class, forward the caching flag |
175 | | - if already_applied: |
176 | | - kwargs['_memoized_instances__use_cache'] = False |
177 | | - obj = new(_cls, *args, **kwargs) |
178 | | - |
179 | | - # Set our initialization flag and return |
180 | | - obj._memoized_instances__initialized = False |
181 | | - return obj |
182 | | - |
183 | | - return cache.get_or_create(*args, _memoized_instances__use_cache=False, **kwargs) |
184 | | - |
185 | | - @wraps(init) |
186 | | - def _init(self: InstanceType, *args: Hashable, **kwargs: Hashable) -> None: |
187 | | - # Skip reinitialization if this object was obtained from the cache |
188 | | - try: |
189 | | - if self._memoized_instances__initialized: |
190 | | - return |
191 | | - except AttributeError: |
192 | | - # If the attribute doesn't exist, this is a new instance |
193 | | - self._memoized_instances__initialized = False |
194 | | - |
195 | | - # Don't forward our extra argument to the original __init__ |
196 | | - kwargs.pop('_memoized_instances__use_cache', None) |
197 | | - init(self, *args, **kwargs) |
198 | | - self._memoized_instances__initialized = True |
| 148 | + # Initialize a weak value cache bound to this constructor |
| 149 | + cache = WeakValueCache(fun) |
199 | 150 |
|
200 | | - def _copy(self: InstanceType) -> InstanceType: |
201 | | - # Copy should just return the cached instance; bypass the cache machinery |
202 | | - return self |
| 151 | + @wraps(fun) |
| 152 | + def _wrapper(cls: type[InstanceType], *args: Hashable, **kwargs: Hashable) \ |
| 153 | + -> InstanceType: |
| 154 | + # Retrieve from the cache or create a new instance |
| 155 | + return cache.get_or_create(cls, *args, **kwargs) |
203 | 156 |
|
204 | | - # Update the class's methods |
205 | | - cls.__new__ = _new |
206 | | - cls.__init__ = _init |
207 | | - cls.__copy__ = _copy |
| 157 | + return _wrapper |
208 | 158 |
|
209 | | - return cls |
0 commit comments