Skip to content

Commit 55bc6f8

Browse files
committed
Rename matrix to t in compute_gramian
1 parent 47bf743 commit 55bc6f8

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/torchjd/_linalg/_gramian.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88

99
@overload
10-
def compute_gramian(matrix: Tensor) -> PSDMatrix:
10+
def compute_gramian(t: Tensor) -> PSDMatrix:
1111
pass
1212

1313

1414
@overload
15-
def compute_gramian(matrix: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
15+
def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
1616
pass
1717

1818

19-
def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDTensor:
19+
def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
2020
"""
2121
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.
2222
@@ -25,11 +25,11 @@ def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDTensor:
2525
first dimension).
2626
"""
2727

28-
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + matrix.ndim
29-
indices_source = list(range(matrix.ndim - contracted_dims))
30-
indices_dest = list(range(matrix.ndim - 1, contracted_dims - 1, -1))
31-
transposed_matrix = matrix.movedim(indices_source, indices_dest)
32-
gramian = torch.tensordot(matrix, transposed_matrix, dims=contracted_dims)
28+
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
29+
indices_source = list(range(t.ndim - contracted_dims))
30+
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
31+
transposed = t.movedim(indices_source, indices_dest)
32+
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
3333
return cast(PSDTensor, gramian)
3434

3535

0 commit comments

Comments
 (0)