Skip to content

Commit e347075

Browse files
Apply suggestions from code review
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
1 parent 2b94d78 commit e347075

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

src/torchjd/_linalg/_matrix.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@
44

55

66
class GeneralizedMatrix(Tensor):
7-
pass
7+
"""Tensor with a least 1 dimension."""
88

99

1010
class Matrix(GeneralizedMatrix):
11-
pass
11+
"""Tensor with exactly 2 dimensions."""
1212

1313

1414
class PSDGeneralizedMatrix(Tensor):
15-
pass
15+
"""
16+
Tensor representing a quadratic form. The first half of its dimensions matches the reversed
17+
second half of its dimensions (e.g. shape=[4, 3, 3, 4]), and its reshaping into a matrix should
18+
be positive semi-definite.
19+
"""
1620

1721

1822
class PSDMatrix(PSDGeneralizedMatrix, Matrix):
19-
pass
23+
"""Positive semi-definite matrix."""
2024

2125

2226
def is_generalized_matrix(t: Tensor) -> TypeGuard[GeneralizedMatrix]:
@@ -31,10 +35,10 @@ def is_psd_generalized_matrix(t: Tensor) -> TypeGuard[PSDGeneralizedMatrix]:
3135
half_dim = t.ndim // 2
3236
return t.ndim % 2 == 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]
3337
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
34-
# every function that use this TypeGuard by using `assert_psd_generalized_matrix`.
38+
# every function that uses this TypeGuard by using `assert_psd_generalized_matrix`.
3539

3640

3741
def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
3842
return t.ndim == 2 and t.shape[0] == t.shape[1]
3943
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
40-
# every function that use this TypeGuard, by using `assert_psd_matrix`.
44+
# every function that uses this TypeGuard, by using `assert_psd_matrix`.

src/torchjd/autogram/_gramian_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGenerali
1616
`shape + shape[::-1]`.
1717
"""
1818

19-
# Example 1: `gramian` of shape [4, 3, 2, 2, 3, 4] and `shape` of [8, 3]:
19+
# Example 1: `gramian` of shape [4, 3, 2, 2, 3, 4] and `half_shape` of [8, 3]:
2020
# [4, 3, 2, 2, 3, 4] -(movedim)-> [4, 3, 2, 4, 3, 2] -(reshape)-> [8, 3, 8, 3] -(movedim)->
2121
# [8, 3, 3, 8]
2222
#
23-
# Example 2: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]:
23+
# Example 2: `gramian` of shape [24, 24] and `half_shape` of [4, 3, 2]:
2424
# [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4]
2525

2626
reshaped_gramian = _revert_last_dims(
@@ -31,8 +31,8 @@ def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGenerali
3131

3232
def flatten(gramian: PSDGeneralizedMatrix) -> PSDMatrix:
3333
k = gramian.ndim // 2
34-
shape = gramian.shape[:k]
35-
m = prod(shape)
34+
half_shape = gramian.shape[:k]
35+
m = prod(half_shape)
3636
square_gramian = reshape(gramian, [m])
3737
return cast(PSDMatrix, square_gramian)
3838

0 commit comments

Comments
 (0)