|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import threading |
| 16 | +import weakref |
16 | 17 | from typing import TYPE_CHECKING, Any, Optional |
17 | 18 |
|
18 | 19 | import nvtx |
@@ -96,14 +97,16 @@ def __init__( |
96 | 97 | else None |
97 | 98 | ) |
98 | 99 |
|
99 | | - def __del__(self): |
100 | | - self._return_op_to_cache() |
| 100 | + if hasattr(self._operator, "_cache"): |
| 101 | + self._return_op_to_cache = weakref.finalize( |
| 102 | + self, Invocation._return_op_to_cache_impl, self._operator |
| 103 | + ) |
| 104 | + else: |
| 105 | + self._return_op_to_cache = None |
101 | 106 |
|
102 | | - def _return_op_to_cache(self): |
103 | | - if (cache := getattr(self._operator, "_cache", None)) is not None: |
104 | | - cache[self._operator._key] = self._operator |
105 | | - self._operator = None |
106 | | - self._return_op_to_cache = lambda: None |
| 107 | + @staticmethod |
| 108 | + def _return_op_to_cache_impl(op): |
| 109 | + op._cache[op._key] = op |
107 | 110 |
|
108 | 111 | def device(self, result_index: int): |
109 | 112 | if self._output_devices is None: |
@@ -305,7 +308,9 @@ def output_device(x): |
305 | 308 | assert output_device(self._results[i]) == d |
306 | 309 |
|
307 | 310 | ctx.cache_results(self, self._results) |
308 | | - self._return_op_to_cache() # the operator instance is ready for a new invocation |
| 311 | + if self._return_op_to_cache: |
| 312 | + self._return_op_to_cache() # the operator instance is ready for a new invocation |
| 313 | + self._operator = None |
309 | 314 |
|
310 | 315 | def values(self, ctx: Optional[_EvalContext] = None): |
311 | 316 | """ |
|
0 commit comments