77
88
99@overload
10- def compute_gramian (matrix : Tensor ) -> PSDMatrix :
10+ def compute_gramian (t : Tensor ) -> PSDMatrix :
1111 pass
1212
1313
1414@overload
15- def compute_gramian (matrix : Tensor , contracted_dims : Literal [- 1 ]) -> PSDMatrix :
15+ def compute_gramian (t : Tensor , contracted_dims : Literal [- 1 ]) -> PSDMatrix :
1616 pass
1717
1818
19- def compute_gramian (matrix : Tensor , contracted_dims : int = - 1 ) -> PSDTensor :
19+ def compute_gramian (t : Tensor , contracted_dims : int = - 1 ) -> PSDTensor :
2020 """
2121 Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.
2222
@@ -25,11 +25,11 @@ def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDTensor:
2525 first dimension).
2626 """
2727
28- contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + matrix .ndim
29- indices_source = list (range (matrix .ndim - contracted_dims ))
30- indices_dest = list (range (matrix .ndim - 1 , contracted_dims - 1 , - 1 ))
31- transposed_matrix = matrix .movedim (indices_source , indices_dest )
32- gramian = torch .tensordot (matrix , transposed_matrix , dims = contracted_dims )
28+ contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t .ndim
29+ indices_source = list (range (t .ndim - contracted_dims ))
30+ indices_dest = list (range (t .ndim - 1 , contracted_dims - 1 , - 1 ))
31+ transposed = t .movedim (indices_source , indices_dest )
32+ gramian = torch .tensordot (t , transposed , dims = contracted_dims )
3333 return cast (PSDTensor , gramian )
3434
3535
0 commit comments