33
44from torch import Tensor
55
6- from torchjd ._linalg import PSDGeneralizedMatrix , PSDMatrix
6+ from torchjd ._linalg import PSDMatrix , PSDTensor
77
88
9- def flatten (gramian : PSDGeneralizedMatrix ) -> PSDMatrix :
9+ def flatten (gramian : PSDTensor ) -> PSDMatrix :
1010 """
1111 Flattens a generalized Gramian into a square matrix. The first half of the dimensions are
1212 flattened into the first dimension, and the second half are flattened into the second.
@@ -24,7 +24,7 @@ def flatten(gramian: PSDGeneralizedMatrix) -> PSDMatrix:
2424 return cast (PSDMatrix , square_gramian )
2525
2626
27- def reshape (gramian : PSDGeneralizedMatrix , half_shape : list [int ]) -> PSDGeneralizedMatrix :
27+ def reshape (gramian : PSDTensor , half_shape : list [int ]) -> PSDTensor :
2828 """
2929 Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions
3030 must be done from the left, while the reshape of the second half must be done from the right.
@@ -42,7 +42,7 @@ def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGenerali
4242 # [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4]
4343
4444 result = _revert_last_dims (_revert_last_dims (gramian ).reshape (half_shape + half_shape ))
45- return cast (PSDGeneralizedMatrix , result )
45+ return cast (PSDTensor , result )
4646
4747
4848def _revert_last_dims (t : Tensor ) -> Tensor :
@@ -53,9 +53,7 @@ def _revert_last_dims(t: Tensor) -> Tensor:
5353 return t .movedim (last_dims , last_dims [::- 1 ])
5454
5555
56- def movedim (
57- gramian : PSDGeneralizedMatrix , half_source : list [int ], half_destination : list [int ]
58- ) -> PSDGeneralizedMatrix :
56+ def movedim (gramian : PSDTensor , half_source : list [int ], half_destination : list [int ]) -> PSDTensor :
5957 """
6058 Moves the dimensions of a Gramian from some source dimensions to destination dimensions. This
6159 must be done simultaneously on the first half of the dimensions and on the second half of the
@@ -86,4 +84,4 @@ def movedim(
8684 source = half_source_ + [last_dim - i for i in half_source_ ]
8785 destination = half_destination_ + [last_dim - i for i in half_destination_ ]
8886 moved_gramian = gramian .movedim (source , destination )
89- return cast (PSDGeneralizedMatrix , moved_gramian )
87+ return cast (PSDTensor , moved_gramian )
0 commit comments