Skip to content

Commit 1d6d57f

Browse files
authored
Cleanup after instance cache rework. (#6209)
* Cleanup after instance cache rework. * Improve instance finalizer. --------- Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
1 parent f66e814 commit 1d6d57f

2 files changed

Lines changed: 21 additions & 14 deletions

File tree

dali/python/nvidia/dali/experimental/dynamic/_invocation.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import threading
16+
import weakref
1617
from typing import TYPE_CHECKING, Any, Optional
1718

1819
import nvtx
@@ -96,14 +97,16 @@ def __init__(
9697
else None
9798
)
9899

99-
def __del__(self):
100-
self._return_op_to_cache()
100+
if hasattr(self._operator, "_cache"):
101+
self._return_op_to_cache = weakref.finalize(
102+
self, Invocation._return_op_to_cache_impl, self._operator
103+
)
104+
else:
105+
self._return_op_to_cache = None
101106

102-
def _return_op_to_cache(self):
103-
if (cache := getattr(self._operator, "_cache", None)) is not None:
104-
cache[self._operator._key] = self._operator
105-
self._operator = None
106-
self._return_op_to_cache = lambda: None
107+
@staticmethod
108+
def _return_op_to_cache_impl(op):
109+
op._cache[op._key] = op
107110

108111
def device(self, result_index: int):
109112
if self._output_devices is None:
@@ -305,7 +308,9 @@ def output_device(x):
305308
assert output_device(self._results[i]) == d
306309

307310
ctx.cache_results(self, self._results)
308-
self._return_op_to_cache() # the operator instance is ready for a new invocation
311+
if self._return_op_to_cache:
312+
self._return_op_to_cache() # the operator instance is ready for a new invocation
313+
self._operator = None
309314

310315
def values(self, ctx: Optional[_EvalContext] = None):
311316
"""

dali/python/nvidia/dali/experimental/dynamic/_ops.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,14 @@ def _input_device(
285285

286286
@classmethod
287287
def _process_params(cls, backend, op_device, batch_size, *raw_args, **raw_kwargs):
288+
"""
289+
Processes run-time parameters passed to the operator to ones that can be consumed DALI
290+
(Batch or Tensor).
291+
292+
This is a class method, as it doesn't require an operator instance - and this method
293+
is essential for proper operator instance caching, as input/argument metadata is a part
294+
of the operator cache key.
295+
"""
288296
is_batch = batch_size is not None
289297
if cls._has_random_state_arg:
290298
from . import random
@@ -348,12 +356,6 @@ def _pre_call(self, *inputs, **args):
348356
def _is_backend_initialized(self):
349357
return self._op_backend is not None
350358

351-
def _reset_backend(self):
352-
self._op_backend = None
353-
self._op_spec = None
354-
self._input_meta = []
355-
self._arg_meta = {}
356-
357359
def _init_spec(self, inputs, args):
358360
if self._op_spec is None:
359361
import nvidia.dali as dali

0 commit comments

Comments
 (0)