22
33from __future__ import annotations
44
5- from typing import Any , Optional , Tuple , Union
5+ from typing import Any , TypeAlias
66
77import torch
88
99from linear_operator .functions ._dsmm import DSMM
1010
11- LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle
11+ LinearOperatorType : TypeAlias = Any # Want this to be "LinearOperator" but runtime type checker can't handle
1212
1313
14- Anysor = Union [ LinearOperatorType , torch .Tensor ]
14+ Anysor : TypeAlias = LinearOperatorType | torch .Tensor
1515
1616
1717def add_diagonal (input : Anysor , diag : torch .Tensor ) -> LinearOperatorType :
@@ -47,9 +47,7 @@ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
4747 return input + diag
4848
4949
50- def diagonalization (
51- input : Anysor , method : Optional [str ] = None
52- ) -> Tuple [torch .Tensor , Union [torch .Tensor , LinearOperatorType ]]:
50+ def diagonalization (input : Anysor , method : str | None = None ) -> tuple [torch .Tensor , torch .Tensor | LinearOperatorType ]:
5351 r"""
5452 Returns a (usually partial) diagonalization of a symmetric positive definite matrix (or batch of matrices).
5553 :math:`\mathbf A`.
@@ -67,7 +65,7 @@ def diagonalization(
6765
6866
6967def dsmm (
70- sparse_mat : Union [ torch .sparse .HalfTensor , torch .sparse .FloatTensor , torch .sparse .DoubleTensor ] ,
68+ sparse_mat : torch .sparse .HalfTensor | torch .sparse .FloatTensor | torch .sparse .DoubleTensor ,
7169 dense_mat : torch .Tensor ,
7270) -> torch .Tensor :
7371 r"""
@@ -111,8 +109,8 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool =
111109
112110
113111def inv_quad_logdet (
114- input : Anysor , inv_quad_rhs : Optional [ torch .Tensor ] = None , logdet : bool = False , reduce_inv_quad : bool = True
115- ) -> Tuple [ Optional [ torch .Tensor ], Optional [ torch .Tensor ] ]:
112+ input : Anysor , inv_quad_rhs : torch .Tensor | None = None , logdet : bool = False , reduce_inv_quad : bool = True
113+ ) -> tuple [ torch .Tensor | None , torch .Tensor | None ]:
116114 r"""
117115 Calls both :func:`inv_quad_logdet` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
118116 However, calling this method is far more efficient and stable than calling each method independently.
@@ -133,8 +131,8 @@ def inv_quad_logdet(
133131
134132
135133def pivoted_cholesky (
136- input : Anysor , rank : int , error_tol : Optional [ float ] = None , return_pivots : bool = False
137- ) -> Union [ torch .Tensor , Tuple [torch .Tensor , torch .Tensor ] ]:
134+ input : Anysor , rank : int , error_tol : float | None = None , return_pivots : bool = False
135+ ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor ]:
138136 r"""
139137 Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices).
140138 :math:`\mathbf L \mathbf L^\top = \mathbf A`.
@@ -161,7 +159,7 @@ def pivoted_cholesky(
161159 return to_linear_operator (input ).pivoted_cholesky (rank = rank , error_tol = error_tol , return_pivots = return_pivots )
162160
163161
164- def root_decomposition (input : Anysor , method : Optional [ str ] = None ) -> LinearOperatorType :
162+ def root_decomposition (input : Anysor , method : str | None = None ) -> LinearOperatorType :
165163 r"""
166164 Returns a (usually low-rank) root decomposition linear operator of the
167165 positive definite matrix (or batch of matrices) :math:`\mathbf A`.
@@ -180,9 +178,9 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe
180178
181179def root_inv_decomposition (
182180 input : Anysor ,
183- initial_vectors : Optional [ torch .Tensor ] = None ,
184- test_vectors : Optional [ torch .Tensor ] = None ,
185- method : Optional [ str ] = None ,
181+ initial_vectors : torch .Tensor | None = None ,
182+ test_vectors : torch .Tensor | None = None ,
183+ method : str | None = None ,
186184) -> LinearOperatorType :
187185 r"""
188186 Returns a (usually low-rank) inverse root decomposition linear operator
@@ -206,7 +204,7 @@ def root_inv_decomposition(
206204 )
207205
208206
209- def solve (input : Anysor , rhs : torch .Tensor , lhs : Optional [ torch .Tensor ] = None ) -> torch .Tensor :
207+ def solve (input : Anysor , rhs : torch .Tensor , lhs : torch .Tensor | None = None ) -> torch .Tensor :
210208 r"""
211209 Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
212210 computes a linear solve with right hand side :math:`\mathbf R`:
@@ -241,8 +239,8 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None)
241239
242240
243241def sqrt_inv_matmul (
244- input : Anysor , rhs : torch .Tensor , lhs : Optional [ torch .Tensor ] = None
245- ) -> Union [ torch .Tensor , Tuple [torch .Tensor , torch .Tensor ] ]:
242+ input : Anysor , rhs : torch .Tensor , lhs : torch .Tensor | None = None
243+ ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor ]:
246244 r"""
247245 Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`
248246 and a right hand size :math:`\mathbf R`,
0 commit comments