Skip to content

Commit 34da7b6

Browse files
authored
Upgrade to use py3.10 features (#123)
1 parent bf1c913 commit 34da7b6

70 files changed

Lines changed: 727 additions & 690 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

linear_operator/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
3+
24
from linear_operator import beta_features, operators, settings, utils
35
from linear_operator.functions import (
46
add_diagonal,

linear_operator/beta_features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
import warnings
45

linear_operator/functions/__init__.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Optional, Tuple, Union
5+
from typing import Any, TypeAlias
66

77
import torch
88

99
from linear_operator.functions._dsmm import DSMM
1010

11-
LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle
11+
LinearOperatorType: TypeAlias = Any # Want this to be "LinearOperator" but runtime type checker can't handle
1212

1313

14-
Anysor = Union[LinearOperatorType, torch.Tensor]
14+
Anysor: TypeAlias = LinearOperatorType | torch.Tensor
1515

1616

1717
def add_diagonal(input: Anysor, diag: torch.Tensor) -> LinearOperatorType:
@@ -47,9 +47,7 @@ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
4747
return input + diag
4848

4949

50-
def diagonalization(
51-
input: Anysor, method: Optional[str] = None
52-
) -> Tuple[torch.Tensor, Union[torch.Tensor, LinearOperatorType]]:
50+
def diagonalization(input: Anysor, method: str | None = None) -> tuple[torch.Tensor, torch.Tensor | LinearOperatorType]:
5351
r"""
5452
Returns a (usually partial) diagonalization of a symmetric positive definite matrix (or batch of matrices).
5553
:math:`\mathbf A`.
@@ -67,7 +65,7 @@ def diagonalization(
6765

6866

6967
def dsmm(
70-
sparse_mat: Union[torch.sparse.HalfTensor, torch.sparse.FloatTensor, torch.sparse.DoubleTensor],
68+
sparse_mat: torch.sparse.HalfTensor | torch.sparse.FloatTensor | torch.sparse.DoubleTensor,
7169
dense_mat: torch.Tensor,
7270
) -> torch.Tensor:
7371
r"""
@@ -111,8 +109,8 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool =
111109

112110

113111
def inv_quad_logdet(
114-
input: Anysor, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
115-
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
112+
input: Anysor, inv_quad_rhs: torch.Tensor | None = None, logdet: bool = False, reduce_inv_quad: bool = True
113+
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
116114
r"""
117115
Calls both :func:`inv_quad_logdet` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
118116
However, calling this method is far more efficient and stable than calling each method independently.
@@ -133,8 +131,8 @@ def inv_quad_logdet(
133131

134132

135133
def pivoted_cholesky(
136-
input: Anysor, rank: int, error_tol: Optional[float] = None, return_pivots: bool = False
137-
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
134+
input: Anysor, rank: int, error_tol: float | None = None, return_pivots: bool = False
135+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
138136
r"""
139137
Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices).
140138
:math:`\mathbf L \mathbf L^\top = \mathbf A`.
@@ -161,7 +159,7 @@ def pivoted_cholesky(
161159
return to_linear_operator(input).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots)
162160

163161

164-
def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOperatorType:
162+
def root_decomposition(input: Anysor, method: str | None = None) -> LinearOperatorType:
165163
r"""
166164
Returns a (usually low-rank) root decomposition linear operator of the
167165
positive definite matrix (or batch of matrices) :math:`\mathbf A`.
@@ -180,9 +178,9 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe
180178

181179
def root_inv_decomposition(
182180
input: Anysor,
183-
initial_vectors: Optional[torch.Tensor] = None,
184-
test_vectors: Optional[torch.Tensor] = None,
185-
method: Optional[str] = None,
181+
initial_vectors: torch.Tensor | None = None,
182+
test_vectors: torch.Tensor | None = None,
183+
method: str | None = None,
186184
) -> LinearOperatorType:
187185
r"""
188186
Returns a (usually low-rank) inverse root decomposition linear operator
@@ -206,7 +204,7 @@ def root_inv_decomposition(
206204
)
207205

208206

209-
def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) -> torch.Tensor:
207+
def solve(input: Anysor, rhs: torch.Tensor, lhs: torch.Tensor | None = None) -> torch.Tensor:
210208
r"""
211209
Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
212210
computes a linear solve with right hand side :math:`\mathbf R`:
@@ -241,8 +239,8 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None)
241239

242240

243241
def sqrt_inv_matmul(
244-
input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None
245-
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
242+
input: Anysor, rhs: torch.Tensor, lhs: torch.Tensor | None = None
243+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
246244
r"""
247245
Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`
248246
and a right hand size :math:`\mathbf R`,

linear_operator/functions/_diagonalization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
import torch
45
from torch.autograd import Function

linear_operator/functions/_dsmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
from torch.autograd import Function
45

linear_operator/functions/_inv_quad.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
import torch
45
from torch.autograd import Function

linear_operator/functions/_inv_quad_logdet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
import warnings
45

linear_operator/functions/_matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
from torch.autograd import Function
45

linear_operator/functions/_pivoted_cholesky.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
import torch
45
from torch.autograd import Function

linear_operator/functions/_root_decomposition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
23

34
import torch
45
from torch.autograd import Function

0 commit comments

Comments (0)