Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9fec107
refactor(linalg): Add `PSDQuadraticForm` and `GeneralizedMatrix`.
PierreQuinton Jan 19, 2026
acf8e58
Merge branch 'main' into add-generalized-matrix-psd-matrix
PierreQuinton Jan 19, 2026
a744fa2
Sort items of `__all__` of `_linalg.__init__`
PierreQuinton Jan 19, 2026
23de54d
one line
PierreQuinton Jan 19, 2026
2bd603e
fix `is_psd_quadratic_form`
PierreQuinton Jan 19, 2026
d6f8375
remove outdated comment
PierreQuinton Jan 19, 2026
24d24bb
Add `assert_psd_quadratic_form` and TODOs for where to test it. I als…
PierreQuinton Jan 19, 2026
242cb55
fix is_psd_quadratic_form
PierreQuinton Jan 20, 2026
72a9a5f
Rename `PSDQuadraticForm` to `PSDGeneralizedMatrix`
PierreQuinton Jan 20, 2026
09df593
fix type of weighting in Flattening
PierreQuinton Jan 20, 2026
0a1d45c
Add parametrization of zero matrix for test_gramian_is_psd
PierreQuinton Jan 20, 2026
6f63182
Add test of the PSD property for functions in aggregation/_utils/gramian
PierreQuinton Jan 20, 2026
5a42ecd
rename test of equivariance accordingly
PierreQuinton Jan 20, 2026
0497f3a
Rename functions in `autogram/_gramian_utils` so that they don't incl…
PierreQuinton Jan 20, 2026
48df0a8
Test the PSD property on outputs of functions in `autogram/_gramian_u…
PierreQuinton Jan 20, 2026
40977f3
Remove internal checks of shapes of matrices
PierreQuinton Jan 20, 2026
92b975b
Remove uninformative shadowing of assertion error in assert_psd_*
PierreQuinton Jan 20, 2026
bda0a5f
Factorize `compute_gramian` from `forward_backward` by making the one…
PierreQuinton Jan 20, 2026
97bcf42
Revert "Factorize `compute_gramian` from `forward_backward` by making…
PierreQuinton Jan 20, 2026
03aebae
Generalizes `compute_gramian` to take a `GeneralizedMatrix` instead.
PierreQuinton Jan 20, 2026
ee54c09
Move `aggregation/_utils/gramian.py` to `_linalg/gramian.py`
PierreQuinton Jan 20, 2026
2b94d78
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey Jan 20, 2026
e347075
Apply suggestions from code review
PierreQuinton Jan 21, 2026
3d9742c
Remove outdated comments
PierreQuinton Jan 21, 2026
f2d0d1b
Improve style
PierreQuinton Jan 21, 2026
5eafa74
Improve typing of `forward_backward.compute_gramian`
PierreQuinton Jan 21, 2026
d60e9fa
improve asserts
PierreQuinton Jan 21, 2026
f4d611b
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey Jan 21, 2026
57af9f1
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey Jan 21, 2026
a793693
Can parametrize number of dimensions to contract in `compute_gramian`
PierreQuinton Jan 22, 2026
7da352c
Remove GeneralizedMatrix
ValerianRey Jan 23, 2026
994932b
Rename PSDGeneralizedMatrix to PSDTensor
ValerianRey Jan 23, 2026
ab809c6
Add comment about using classes
ValerianRey Jan 23, 2026
47bf743
Remove useless overload of compute_gramian
ValerianRey Jan 23, 2026
55bc6f8
Rename matrix to t in compute_gramian
ValerianRey Jan 23, 2026
a80f3f6
Add overload for compute_gramian when t is matrix and contracted_dims…
ValerianRey Jan 23, 2026
09393cc
Stop expecting coverage for overload functions
ValerianRey Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
from ._gramian import compute_gramian
from ._matrix import (
    GeneralizedMatrix,
    Matrix,
    PSDMatrix,
    PSDQuadraticForm,
    is_generalized_matrix,
    is_matrix,
    is_psd_matrix,
    is_psd_quadratic_form,
)

# Public names of the `_linalg` package, kept sorted (case-insensitively).
__all__ = [
    "compute_gramian",
    "GeneralizedMatrix",
    "Matrix",
    "PSDMatrix",
    "PSDQuadraticForm",
    "is_generalized_matrix",
    "is_matrix",
    "is_psd_matrix",
    "is_psd_quadratic_form",
]
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from ._matrix import Matrix, PSDMatrix, is_psd_matrix


def compute_gramian(matrix: Matrix) -> PSDMatrix:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
    """

    gramian = matrix @ matrix.T
    # M @ M.T is PSD by construction for any real matrix; the guard only verifies squareness.
    assert is_psd_matrix(gramian)
    return gramian
42 changes: 42 additions & 0 deletions src/torchjd/_linalg/_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import TypeGuard

from torch import Tensor


class GeneralizedMatrix(Tensor):
pass
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated


class Matrix(GeneralizedMatrix):
pass
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated


class PSDQuadraticForm(Tensor):
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
pass
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated


class PSDMatrix(PSDQuadraticForm, Matrix):
pass
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated


def is_generalized_matrix(t: Tensor) -> TypeGuard[GeneralizedMatrix]:
return t.ndim >= 1


def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
return t.ndim == 2


def is_psd_quadratic_form(t: Tensor) -> TypeGuard[PSDQuadraticForm]:
half_dim = t.ndim // 2
return not t.ndim % 2 != 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
# every function that use this TypeGuard.
# TODO: Say with what assert we check that

Check failure on line 35 in src/torchjd/_linalg/_matrix.py

View workflow job for this annotation

GitHub Actions / check-todos

TODO found at src/torchjd/_linalg/_matrix.py:35 - must be resolved before merge: # TODO: Say with what assert we check that


def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
return t.ndim == 2 and t.shape[0] == t.shape[1]
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
# every function that use this TypeGuard.
# TODO: Say with what assert we check that

Check failure on line 42 in src/torchjd/_linalg/_matrix.py

View workflow job for this annotation

GitHub Actions / check-todos

TODO found at src/torchjd/_linalg/_matrix.py:42 - must be resolved before merge: # TODO: Say with what assert we check that
Comment thread
ValerianRey marked this conversation as resolved.
Outdated
6 changes: 0 additions & 6 deletions src/torchjd/_linalg/matrix.py

This file was deleted.

16 changes: 8 additions & 8 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import Tensor, nn

from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix

from ._weighting_bases import Weighting

Expand All @@ -18,20 +18,21 @@ def __init__(self):

@staticmethod
def _check_is_matrix(matrix: Tensor) -> None:
    """Raises a ValueError if ``matrix`` is not a 2-dimensional tensor."""

    if not is_matrix(matrix):
        raise ValueError(
            "Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
            f"{matrix.shape}`."
        )

@abstractmethod
def forward(self, matrix: Matrix) -> Tensor:
    """Computes the aggregation from the input matrix."""

# Override to make type hints and documentation more specific. Note that the `Matrix` type isn't
# public, so `__call__` keeps `Tensor` in its signature.
def __call__(self, matrix: Tensor) -> Tensor:
    """Computes the aggregation from the input matrix and applies all registered hooks."""

    Aggregator._check_is_matrix(matrix)
    return super().__call__(matrix)

def __repr__(self) -> str:
Expand All @@ -54,7 +55,7 @@ def __init__(self, weighting: Weighting[Matrix]):
self.weighting = weighting

@staticmethod
def combine(matrix: Matrix, weights: Tensor) -> Tensor:
    """
    Aggregates a matrix by making a linear combination of its rows, using the provided vector of
    weights.
    """

    vector = weights @ matrix
    return vector

def forward(self, matrix: Matrix) -> Tensor:
    # The input was already validated by `Aggregator.__call__`, so no shape check is needed here.
    weights = self.weighting(matrix)
    vector = self.combine(matrix, weights)
    return vector
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._sum import SumWeighting
from ._utils.non_differentiable import raise_non_differentiable_error
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None):
# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Matrix) -> Tensor:
weights = self.weighting(matrix)
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
best_direction = torch.linalg.pinv(units) @ weights
Expand Down
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_flattening.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import Tensor

from torchjd._linalg.matrix import PSDMatrix
from torchjd._linalg import PSDMatrix, PSDQuadraticForm, is_psd_matrix
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
from torchjd.autogram._gramian_utils import reshape_gramian

Expand All @@ -26,11 +26,12 @@ def __init__(self, weighting: Weighting[PSDMatrix]):
super().__init__()
self.weighting = weighting

def forward(self, generalized_gramian: PSDQuadraticForm) -> Tensor:
    # A PSDQuadraticForm has 2k dimensions; the first k give the shape of the weights to return.
    k = generalized_gramian.ndim // 2
    shape = generalized_gramian.shape[:k]
    m = prod(shape)
    # Flatten the generalized Gramian into an equivalent square (m x m) Gramian matrix.
    square_gramian = reshape_gramian(generalized_gramian, [m])
    assert is_psd_matrix(square_gramian)
    weights_vector = self.weighting(square_gramian)
    weights = weights_vector.reshape(shape)
    return weights
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._utils.non_differentiable import raise_non_differentiable_error

Expand Down Expand Up @@ -38,8 +40,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Tensor) -> Tensor:
self._check_is_matrix(matrix)
def forward(self, matrix: Matrix) -> Tensor:
self._check_matrix_has_enough_rows(matrix)

if matrix.shape[0] == 0 or matrix.shape[1] == 0:
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

import torch
from torch import Tensor

Expand Down Expand Up @@ -32,7 +34,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
device = gramian.device
dtype = gramian.dtype
cpu = torch.device("cpu")
gramian = gramian.to(device=cpu)
gramian = cast(PSDMatrix, gramian.to(device=cpu))

dimension = gramian.shape[0]
weights = torch.zeros(dimension, device=cpu, dtype=dtype)
Expand Down
1 change: 0 additions & 1 deletion src/torchjd/aggregation/_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, trim_number: int):
self.trim_number = trim_number

def forward(self, matrix: Tensor) -> Tensor:
self._check_is_matrix(matrix)
self._check_matrix_has_enough_rows(matrix)

n_rows = matrix.shape[0]
Expand Down
12 changes: 8 additions & 4 deletions src/torchjd/aggregation/_utils/gramian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from torchjd._linalg.matrix import PSDMatrix
from torchjd._linalg import PSDMatrix, is_psd_matrix


def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
    """
    Normalizes a Gramian by its trace (equal to the squared Frobenius norm of the underlying
    matrix). Returns the zero matrix when the trace is smaller than ``eps``, to avoid dividing by
    a negligible value.
    """

    squared_frobenius_norm = gramian.diagonal().sum()
    if squared_frobenius_norm < eps:
        output = torch.zeros_like(gramian)
    else:
        output = gramian / squared_frobenius_norm
    # Scaling by a non-negative factor (or zeroing) preserves positive semi-definiteness.
    assert is_psd_matrix(output)
    return output


def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
    """
    Regularizes a Gramian by adding ``eps`` times the identity matrix to it. This shifts every
    eigenvalue up by ``eps``, so the result stays PSD.
    """

    regularization_matrix = eps * torch.eye(
        gramian.shape[0], dtype=gramian.dtype, device=gramian.device
    )
    output = gramian + regularization_matrix
    assert is_psd_matrix(output)
    return output
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_utils/pref_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import Tensor

from torchjd._linalg.matrix import Matrix
from torchjd._linalg import Matrix
from torchjd.aggregation._constant import ConstantWeighting
from torchjd.aggregation._weighting_bases import Weighting

Expand Down
5 changes: 4 additions & 1 deletion src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from torch import Tensor, nn

from torchjd._linalg import PSDQuadraticForm, is_psd_quadratic_form

_T = TypeVar("_T", contravariant=True)
_FnInputT = TypeVar("_FnInputT")
_FnOutputT = TypeVar("_FnOutputT")
Expand Down Expand Up @@ -64,7 +66,7 @@ def __init__(self):
super().__init__()

@abstractmethod
def forward(self, generalized_gramian: PSDQuadraticForm) -> Tensor:
    """Computes the vector of weights from the input generalized Gramian."""

# Override to make type hints and documentation more specific
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
def __call__(self, generalized_gramian: Tensor) -> Tensor:
    """
    Computes the vector of weights from the input generalized Gramian and applies all registered
    hooks.
    """

    # Only the shape is verified here; actual PSD-ness must be covered by the callers' tests.
    assert is_psd_quadratic_form(generalized_gramian)
    return super().__call__(generalized_gramian)
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import Tensor, nn, vmap
from torch.autograd.graph import get_gradient_edge

from torchjd._linalg.matrix import PSDMatrix
from torchjd._linalg import PSDMatrix

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_gramian_accumulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from torchjd._linalg.matrix import PSDMatrix
from torchjd._linalg import PSDMatrix


class GramianAccumulator:
Expand Down
13 changes: 7 additions & 6 deletions src/torchjd/autogram/_gramian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from torch import Tensor
from torch.utils._pytree import PyTree

from torchjd._linalg import compute_gramian
from torchjd._linalg.matrix import PSDMatrix
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix
from torchjd.autogram._jacobian_computer import JacobianComputer


Expand All @@ -23,12 +22,12 @@ def __call__(
def track_forward_call(self) -> None:
"""Track that the module's forward was called. Necessary in some implementations."""

def reset(self) -> None:
    """Reset state if any. Necessary in some implementations."""


class JacobianBasedGramianComputer(GramianComputer, ABC):
    """Base class for Gramian computers that work from a Jacobian computed by a JacobianComputer."""

    def __init__(self, jacobian_computer: JacobianComputer):
        self.jacobian_computer = jacobian_computer


Expand All @@ -41,7 +40,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
def __init__(self, jacobian_computer: JacobianComputer):
    super().__init__(jacobian_computer)
    # Decremented on each __call__; the Gramian is presumably computed once it reaches zero —
    # TODO(review): confirm against the hidden part of __call__.
    self.remaining_counter = 0
    # Running sum of the Jacobian matrices received so far (None until the first one arrives).
    self.summed_jacobian: Optional[Matrix] = None

def reset(self) -> None:
self.remaining_counter = 0
Expand All @@ -64,7 +63,9 @@ def __call__(
if self.summed_jacobian is None:
self.summed_jacobian = jacobian_matrix
else:
self.summed_jacobian += jacobian_matrix
jacobians_sum = self.summed_jacobian + jacobian_matrix
assert is_matrix(jacobians_sum)
self.summed_jacobian = jacobians_sum

self.remaining_counter -= 1

Expand Down
Loading
Loading