
Commit 386cff7

Merge pull request #128 from cornellius-gp/fix-deprecated-apis-and-bugs
Fix deprecated APIs, latent bugs, and code quality issues
2 parents b9e8650 + 504c1a4 commit 386cff7

12 files changed

Lines changed: 25 additions & 43 deletions

linear_operator/operators/_linear_operator.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def _implements_second_arg(torch_function: Callable) -> Callable:
     where the first argument of the function is a torch.Tensor and the
     second argument is a LinearOperator

-    Examples of this include :meth:`torch.cholesky_solve`, `torch.linalg.solve`, or `torch.matmul`.
+    Examples of this include `torch.linalg.solve` or `torch.matmul`.
     """

     @functools.wraps(torch_function)
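
For context, `_implements_second_arg` registers handlers for torch functions in which the Tensor is the first argument and the LinearOperator the second, so that `__torch_function__` dispatch still reaches the operator. A minimal sketch of such a call, assuming the package's top-level `to_linear_operator` helper:

    import torch
    from linear_operator import to_linear_operator

    A = to_linear_operator(torch.eye(3))  # LinearOperator wrapping a dense matrix
    x = torch.randn(2, 3)

    # x is a torch.Tensor and A is a LinearOperator; dispatching on the
    # second argument is exactly what _implements_second_arg enables.
    res = torch.matmul(x, A)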

linear_operator/operators/batch_repeat_linear_operator.py

Lines changed: 3 additions & 3 deletions
@@ -139,8 +139,8 @@ def _move_repeat_batches_back(self, batch_matrix, output_shape):
         So that the tensor is now rb x m x n.
         """
         if hasattr(self, "_batch_move_memo"):
-            padded_base_batch_shape, batch_repeat = self.__batch_move_memo
-            del self.__batch_move_memo
+            _, padded_base_batch_shape, batch_repeat = self._batch_move_memo
+            del self._batch_move_memo
         else:
             padding_dims = torch.Size(tuple(1 for _ in range(max(len(output_shape) - self.base_linear_op.dim(), 0))))
             padded_base_batch_shape = padding_dims + self.base_linear_op.batch_shape

@@ -188,7 +188,7 @@ def _move_repeat_batches_to_columns(self, batch_matrix, output_shape):
         batch_matrix = batch_matrix.permute(*batch_dims, -2, -1, *repeat_dims).contiguous()
         batch_matrix = batch_matrix.view(*self.base_linear_op.batch_shape, output_shape[-2], -1)

-        self.__batch_move_memo = output_shape, padded_base_batch_shape, batch_repeat
+        self._batch_move_memo = output_shape, padded_base_batch_shape, batch_repeat
         return batch_matrix

     def _permute_batch(self, *dims: int) -> LinearOperator:
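
The root cause here is Python name mangling: an attribute written as `self.__batch_move_memo` inside the class body is actually stored as `self._BatchRepeatLinearOperator__batch_move_memo`, so the guard `hasattr(self, "_batch_move_memo")` could never see it and the slower `else` branch always ran. The fix also corrects the unpacking, since the memo is a 3-tuple whose first element (`output_shape`) is not needed on read. A standalone sketch of the mangling pitfall (class and attribute names hypothetical):

    class Demo:
        def stash(self):
            # Stored under the mangled name _Demo__memo, not __memo
            self.__memo = ("shape", "batch_shape", "repeat")

        def check(self):
            # Mirrors the buggy guard: it looks for a single-underscore name
            return hasattr(self, "_memo")

    d = Demo()
    d.stash()
    print(d.check())                  # False -- the guard never fires
    print(hasattr(d, "_Demo__memo"))  # True -- where the value actually lives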

linear_operator/operators/cat_linear_operator.py

Lines changed: 3 additions & 4 deletions
@@ -11,7 +11,6 @@
 from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator

 from linear_operator.utils.broadcasting import _matmul_broadcast_shape
-from linear_operator.utils.deprecation import bool_compat
 from linear_operator.utils.generic import _to_helper
 from linear_operator.utils.getitem import _noop_index

@@ -188,7 +187,7 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices

         # Find out for which indices we switch to different tensors
         target_tensors = self.idx_to_tensor_idx[cat_dim_indices]
-        does_switch_tensor = torch.ones(target_tensors.numel() + 1, dtype=bool_compat, device=self.device)
+        does_switch_tensor = torch.ones(target_tensors.numel() + 1, dtype=torch.bool, device=self.device)
         torch.ne(target_tensors[:-1], target_tensors[1:], out=does_switch_tensor[1:-1])

         # Get the LinearOperators that will comprise the new LinearOperator

@@ -258,7 +257,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType)

         # Find out for which indices we switch to different tensors
         target_tensors = self.idx_to_tensor_idx[cat_dim_indices]
-        does_switch_tensor = torch.ones(target_tensors.numel() + 1, dtype=bool_compat, device=self.device)
+        does_switch_tensor = torch.ones(target_tensors.numel() + 1, dtype=torch.bool, device=self.device)
         torch.ne(target_tensors[:-1], target_tensors[1:], out=does_switch_tensor[1:-1])

         # Get the LinearOperators that will comprise the new LinearOperator

@@ -294,7 +293,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType)

         else:
             raise RuntimeError(
-                "Unexpected index type {cat_dim_indices.__class__.__name__}. This is a bug in LinearOperator."
+                f"Unexpected index type {cat_dim_indices.__class__.__name__}. This is a bug in LinearOperator."
             )

         # Process the list
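
The last hunk fixes a message that could never interpolate: without the `f` prefix, the braces are printed literally. A tiny illustration:

    cls_name = "slice"
    print("Unexpected index type {cls_name}.")   # Unexpected index type {cls_name}.
    print(f"Unexpected index type {cls_name}.")  # Unexpected index type slice.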

linear_operator/operators/kronecker_product_linear_operator.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ def _matmul(linear_ops, kp_shape, rhs):
     output_shape = _matmul_broadcast_shape(kp_shape, rhs.shape)
     output_batch_shape = output_shape[:-2]

-    res = rhs.contiguous().expand(*output_batch_shape, *rhs.shape[-2:])
+    res = rhs.expand(*output_batch_shape, *rhs.shape[-2:])
     num_cols = rhs.size(-1)
     for linear_op in linear_ops:
         res = res.view(*output_batch_shape, linear_op.size(-1), -1)

@@ -50,7 +50,7 @@ def _t_matmul(linear_ops, kp_shape, rhs):
     output_shape = _matmul_broadcast_shape(kp_t_shape, rhs.shape)
     output_batch_shape = torch.Size(output_shape[:-2])

-    res = rhs.contiguous().expand(*output_batch_shape, *rhs.shape[-2:])
+    res = rhs.expand(*output_batch_shape, *rhs.shape[-2:])
     num_cols = rhs.size(-1)
     for linear_op in linear_ops:
         res = res.view(*output_batch_shape, linear_op.size(-2), -1)
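
Dropping `.contiguous()` is safe because it never helped: `.expand` returns a broadcasted, generally non-contiguous view whether or not its input was contiguous, so materializing a copy first only costs memory bandwidth. A quick sketch:

    import torch

    rhs = torch.randn(3, 4)
    expanded = rhs.contiguous().expand(5, 3, 4)

    # The copy bought nothing: the expanded result is a
    # non-contiguous view either way.
    assert not expanded.is_contiguous()
    assert not rhs.expand(5, 3, 4).is_contiguous()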

linear_operator/operators/low_rank_root_added_diag_linear_operator.py

Lines changed: 3 additions & 1 deletion
@@ -77,7 +77,9 @@ def _solve(
         chol_cap_mat = self.chol_cap_mat

         res = V.matmul(A_inv.matmul(rhs))
-        res = torch.cholesky_solve(res, chol_cap_mat)
+        res = torch.linalg.solve_triangular(
+            chol_cap_mat.mT, torch.linalg.solve_triangular(chol_cap_mat, res, upper=False), upper=True
+        )
         res = A_inv.matmul(U.matmul(res))

         solve = A_inv.matmul(rhs) - res
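
The replacement computes the same quantity: `torch.cholesky_solve(B, L)` solves `(L @ L.mT) X = B` given a lower Cholesky factor `L`, which factors into a forward triangular solve with `L` followed by a backward solve with `L.mT`. A quick numerical check that the two spellings agree:

    import torch

    A = torch.randn(5, 5)
    A = A @ A.mT + 5 * torch.eye(5)  # symmetric positive definite
    L = torch.linalg.cholesky(A)     # lower triangular, A = L @ L.mT
    B = torch.randn(5, 3)

    old = torch.cholesky_solve(B, L)  # solves (L @ L.mT) X = B in one call
    new = torch.linalg.solve_triangular(
        L.mT, torch.linalg.solve_triangular(L, B, upper=False), upper=True
    )
    assert torch.allclose(old, new, atol=1e-6)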

linear_operator/operators/sum_linear_operator.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def __init__(self, *linear_ops, **kwargs):
     def _diagonal(
         self: LinearOperator,  # shape: (..., M, N)
     ) -> torch.Tensor:  # shape: (..., N)
-        return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)
+        return sum(linear_op._diagonal() for linear_op in self.linear_ops)

     def _expand_batch(
         self: LinearOperator, batch_shape: torch.Size | list[int]  # shape: (..., M, N)
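
As with the Kronecker change, the removed `.contiguous()` was pure overhead: element-wise addition reads through strides, so non-contiguous diagonals sum correctly without an intermediate copy. For instance:

    import torch

    a = torch.randn(4, 4)
    d = a.diagonal(dim1=-2, dim2=-1)  # a strided view; not contiguous
    assert not d.is_contiguous()

    b = torch.randn(4)
    # addition handles arbitrary strides, so the explicit copy was unnecessary
    assert torch.allclose(d + b, d.contiguous() + b)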

linear_operator/operators/triangular_linear_operator.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def inv_quad_logdet(
             inv_quad_term = (inv_quad_rhs * self.solve(inv_quad_rhs)).sum(dim=-2)
         if logdet:
             diag = self._diagonal()
-            logdet_term = self._diagonal().abs().log().sum(-1)
+            logdet_term = diag.abs().log().sum(-1)
             if torch.sign(diag).prod(-1) < 0:
                 logdet_term = torch.full_like(logdet_term, float("nan"))
         else:
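
The surrounding logic uses the fact that a triangular matrix's determinant is the product of its diagonal, so `log|det T| = sum_i log|t_ii|`; the one-line fix simply reuses the `diag` computed on the previous line instead of evaluating `_diagonal()` twice. The identity, checked numerically:

    import torch

    T = torch.randn(4, 4).tril()
    T.diagonal().abs_().add_(1.0)  # keep the diagonal safely away from zero

    logdet_from_diag = T.diagonal().abs().log().sum(-1)
    assert torch.allclose(logdet_from_diag, torch.logdet(T), atol=1e-5)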

linear_operator/test/linear_operator_test_case.py

Lines changed: 0 additions & 1 deletion
@@ -666,7 +666,6 @@ def test_add_jitter(self):
         self.assertAllClose(res, actual)

     def test_add_low_rank(self):
-        linear_op = self.create_linear_op()
         linear_op = self.create_linear_op()
         evaluated = self.evaluate_linear_op(linear_op)
         new_rows = torch.randn(*linear_op.shape[:-1], 3)

linear_operator/utils/deprecation.py

Lines changed: 0 additions & 9 deletions
@@ -3,15 +3,6 @@

 import functools
 import warnings
-from unittest.mock import MagicMock
-
-import torch
-
-# TODO: Use bool instead of uint8 dtype once pytorch #21113 is in stable release
-if isinstance(torch, MagicMock):
-    bool_compat = torch.uint8
-else:
-    bool_compat = (torch.ones(1) > 0).dtype


 class DeprecationError(Exception):
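
The deleted shim dates from PyTorch versions whose comparison ops returned `torch.uint8`: it probed the result dtype at import time, with a `MagicMock` escape hatch so the module could still be imported when torch itself was mocked out (as in docs builds). On any current release the probe just rediscovers `torch.bool`, so call sites can name the dtype directly:

    import torch

    # What bool_compat computed: the dtype of a comparison's result
    probed = (torch.ones(1) > 0).dtype
    assert probed is torch.bool  # holds on all current PyTorch releases

    # So the shim's users can simply write:
    mask = torch.empty(3, dtype=torch.bool)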

linear_operator/utils/linear_cg.py

Lines changed: 2 additions & 3 deletions
@@ -6,7 +6,6 @@
 import torch

 from linear_operator import settings
-from linear_operator.utils.deprecation import bool_compat
 from linear_operator.utils.warnings import NumericalWarning


@@ -219,7 +218,7 @@ def linear_cg(
     mul_storage = torch.empty_like(residual)
     alpha = torch.empty(*batch_shape, 1, rhs.size(-1), dtype=residual.dtype, device=residual.device)
     beta = torch.empty_like(alpha)
-    is_zero = torch.empty(*batch_shape, 1, rhs.size(-1), dtype=bool_compat, device=residual.device)
+    is_zero = torch.empty(*batch_shape, 1, rhs.size(-1), dtype=torch.bool, device=residual.device)

     # Define tridiagonal matrices, if applicable
     if n_tridiag:

@@ -231,7 +230,7 @@
             dtype=alpha.dtype,
             device=alpha.device,
         )
-        alpha_tridiag_is_zero = torch.empty(*batch_shape, n_tridiag, dtype=bool_compat, device=t_mat.device)
+        alpha_tridiag_is_zero = torch.empty(*batch_shape, n_tridiag, dtype=torch.bool, device=t_mat.device)
         alpha_reciprocal = torch.empty(*batch_shape, n_tridiag, dtype=t_mat.dtype, device=t_mat.device)
         prev_alpha_reciprocal = torch.empty_like(alpha_reciprocal)
         prev_beta = torch.empty_like(alpha_reciprocal)
