Skip to content

Commit 6eddc91

Browse files
authored
refactor(linalg): Add missing overload to compute_gramian (#540)
1 parent d4e9957 commit 6eddc91

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/torchjd/_linalg/_gramian.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
2121
pass
2222

2323

24+
@overload
25+
def compute_gramian(t: Tensor, contracted_dims: int) -> PSDTensor:
26+
pass
27+
28+
2429
def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
2530
"""
2631
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.

0 commit comments

Comments
 (0)