|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 |
|
3 | | -import functools |
4 | 3 | import string |
5 | 4 | import warnings |
6 | 5 |
|
|
25 | 24 | from .. import settings |
26 | 25 | from ..distributions import MultitaskMultivariateNormal |
27 | 26 | from ..lazy import LazyEvaluatedKernelTensor |
28 | | -from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache |
| 27 | +from ..utils.memoize import add_to_cache, cached, pop_from_cache, register_cache_clear_hook |
29 | 28 |
|
30 | 29 |
|
31 | 30 | def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood): |
@@ -108,10 +107,7 @@ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root |
108 | 107 | if settings.detach_test_caches.on(): |
109 | 108 | res = res.detach() |
110 | 109 |
|
111 | | - if res.grad_fn is not None: |
112 | | - wrapper = functools.partial(clear_cache_hook, self) |
113 | | - functools.update_wrapper(wrapper, clear_cache_hook) |
114 | | - res.grad_fn.register_hook(wrapper) |
| 110 | + register_cache_clear_hook(res, self) |
115 | 111 |
|
116 | 112 | return res |
117 | 113 |
|
@@ -313,10 +309,7 @@ def _mean_cache(self, nan_policy: str) -> Tensor: |
313 | 309 | if settings.detach_test_caches.on(): |
314 | 310 | mean_cache = mean_cache.detach() |
315 | 311 |
|
316 | | - if mean_cache.grad_fn is not None: |
317 | | - wrapper = functools.partial(clear_cache_hook, self) |
318 | | - functools.update_wrapper(wrapper, clear_cache_hook) |
319 | | - mean_cache.grad_fn.register_hook(wrapper) |
| 312 | + register_cache_clear_hook(mean_cache, self) |
320 | 313 |
|
321 | 314 | return mean_cache |
322 | 315 |
|
|
0 commit comments