Skip to content

Commit ee0564b

Browse files
authored
Merge pull request #2734 from saitcakmak/fix-memory-leak-mean-cache
Fix memory leak in DefaultPredictionStrategy cache hooks (#2631)
2 parents 6da5b28 + 92e464c commit ee0564b

3 files changed

Lines changed: 24 additions & 16 deletions

File tree

gpytorch/models/exact_prediction_strategies.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python3
22

3-
import functools
43
import string
54
import warnings
65

@@ -25,7 +24,7 @@
2524
from .. import settings
2625
from ..distributions import MultitaskMultivariateNormal
2726
from ..lazy import LazyEvaluatedKernelTensor
28-
from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
27+
from ..utils.memoize import add_to_cache, cached, pop_from_cache, register_cache_clear_hook
2928

3029

3130
def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood):
@@ -108,10 +107,7 @@ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root
108107
if settings.detach_test_caches.on():
109108
res = res.detach()
110109

111-
if res.grad_fn is not None:
112-
wrapper = functools.partial(clear_cache_hook, self)
113-
functools.update_wrapper(wrapper, clear_cache_hook)
114-
res.grad_fn.register_hook(wrapper)
110+
register_cache_clear_hook(res, self)
115111

116112
return res
117113

@@ -313,10 +309,7 @@ def _mean_cache(self, nan_policy: str) -> Tensor:
313309
if settings.detach_test_caches.on():
314310
mean_cache = mean_cache.detach()
315311

316-
if mean_cache.grad_fn is not None:
317-
wrapper = functools.partial(clear_cache_hook, self)
318-
functools.update_wrapper(wrapper, clear_cache_hook)
319-
mean_cache.grad_fn.register_hook(wrapper)
312+
register_cache_clear_hook(mean_cache, self)
320313

321314
return mean_cache
322315

gpytorch/utils/memoize.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import functools
66
import pickle
7+
import weakref
78

89
from .errors import CachingError
910

@@ -46,6 +47,24 @@ def clear_cache_hook(module, *args, **kwargs):
4647
module._memoize_cache = {}
4748

4849

50+
def register_cache_clear_hook(tsr, module):
51+
"""Register a backward hook on tsr's grad_fn that clears module's cache.
52+
53+
Uses a weak reference to module to avoid creating an uncollectable
54+
reference cycle through the C++ grad_fn object (which Python's cycle
55+
GC cannot see through).
56+
"""
57+
if tsr.grad_fn is not None:
58+
weak_module = weakref.ref(module)
59+
60+
def hook(*args, **kwargs):
61+
obj = weak_module()
62+
if obj is not None:
63+
obj._memoize_cache = {}
64+
65+
tsr.grad_fn.register_hook(hook)
66+
67+
4968
def _cached(method=None, name=None):
5069
"""A decorator allowing for specifying the name of a cache, allowing it to be modified elsewhere.
5170
This variant honors the calling args to the decorated function.

gpytorch/variational/_variational_strategy.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import functools
65
from abc import ABC, abstractproperty
76
from copy import deepcopy
87

@@ -18,7 +17,7 @@
1817
from ..models import ApproximateGP, ExactGP
1918
from ..models.exact_prediction_strategies import DefaultPredictionStrategy
2019
from ..module import Module
21-
from ..utils.memoize import add_to_cache, cached, clear_cache_hook
20+
from ..utils.memoize import add_to_cache, cached, clear_cache_hook, register_cache_clear_hook
2221
from . import _VariationalDistribution
2322

2423

@@ -42,10 +41,7 @@ def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
4241

4342

4443
def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
45-
if tsr.grad_fn is not None:
46-
wrapper = functools.partial(clear_cache_hook, pred_strat)
47-
functools.update_wrapper(wrapper, clear_cache_hook)
48-
tsr.grad_fn.register_hook(wrapper)
44+
register_cache_clear_hook(tsr, pred_strat)
4945
return tsr
5046

5147

0 commit comments

Comments
 (0)