Skip to content

Commit b0a4312

Browse files
kayween and Claude authored
Make linear operators pickleable (#124)
* use qualified names as cache keys * test pickle with populated caches * make `KernelLinearOperator` pickle-able Replace lambda in defaultdict with a module-level function so that num_nonbatch_dimensions can be pickled. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 34da7b6 commit b0a4312

3 files changed

Lines changed: 19 additions & 4 deletions

File tree

linear_operator/operators/kernel_linear_operator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
from linear_operator.utils.memoize import cached
1515

1616

17+
def _default_num_nonbatch_dimensions():
18+
return 2
19+
20+
1721
def _x_getitem(x, batch_indices, data_index):
1822
"""
1923
Helper function to compute x[*batch_indices, data_index, :] in an efficient way.
@@ -142,9 +146,9 @@ def __init__(
142146
):
143147
# Change num_nonbatch_dimensions into a default dict
144148
if num_nonbatch_dimensions is None:
145-
num_nonbatch_dimensions = defaultdict(lambda: 2)
149+
num_nonbatch_dimensions = defaultdict(_default_num_nonbatch_dimensions)
146150
else:
147-
num_nonbatch_dimensions = defaultdict(lambda: 2, **num_nonbatch_dimensions)
151+
num_nonbatch_dimensions = defaultdict(_default_num_nonbatch_dimensions, **num_nonbatch_dimensions)
148152

149153
# Divide params into tensors and non-tensors
150154
tensor_params = dict()

linear_operator/test/linear_operator_test_case.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import logging
55
import math
6+
import pickle
67
import traceback
78
from abc import abstractmethod
89
from itertools import combinations, product
@@ -955,6 +956,16 @@ def test_logdet(self):
955956
if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
956957
self.assertAllClose(arg.grad, arg_copy.grad, **tolerances)
957958

959+
def test_pickle(self):
960+
linear_op = self.create_linear_op()
961+
962+
# Make sure that pickle works with populated caches
963+
_ = linear_op.to_dense()
964+
965+
pickled = pickle.dumps(linear_op)
966+
unpickled = pickle.loads(pickled)
967+
self.assertAllClose(unpickled.to_dense(), linear_op.to_dense())
968+
958969
def test_prod(self):
959970
with linear_operator.settings.fast_computations(covar_root_decomposition=False):
960971
linear_op = self.create_linear_op()

linear_operator/utils/memoize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _cached(method=None, name=None):
5454

5555
@functools.wraps(method)
5656
def g(self, *args, **kwargs):
57-
cache_name = name if name is not None else method
57+
cache_name = name if name is not None else method.__qualname__
5858
kwargs_pkl = pickle.dumps(kwargs)
5959
if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
6060
return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
@@ -72,7 +72,7 @@ def _cached_ignore_args(method=None, name=None):
7272

7373
@functools.wraps(method)
7474
def g(self, *args, **kwargs):
75-
cache_name = name if name is not None else method
75+
cache_name = name if name is not None else method.__qualname__
7676
if not _is_in_cache_ignore_args(self, cache_name):
7777
return _add_to_cache_ignore_args(self, cache_name, method(self, *args, **kwargs))
7878
return _get_from_cache_ignore_args(self, cache_name)

0 commit comments

Comments (0)