Skip to content

Commit 98d1d1b

Browse files
committed
Split structure extraction and matrix ignoring from weight extraction
- The idea is to make it explicit that some weightings are based on the structure of the matrix (sum, mean, random) and that some are independent of their input (constant), by separating the part where we extract the structure (or the none for constant) and the part where we use this structure (or the none for constant) to make a vector of weigths.
1 parent d33cdf5 commit 98d1d1b

8 files changed

Lines changed: 88 additions & 61 deletions

File tree

src/torchjd/_linalg/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from ._generalized_gramian import flatten, movedim, reshape
22
from ._gramian import compute_gramian, normalize, regularize
33
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
4+
from ._structure import Structure, extract_structure
45

56
__all__ = [
7+
"extract_structure",
8+
"Structure",
69
"compute_gramian",
710
"normalize",
811
"regularize",

src/torchjd/_linalg/_structure.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from attr import dataclass
3+
4+
from torchjd._linalg import Matrix
5+
6+
7+
@dataclass
8+
class Structure:
9+
m: int
10+
device: torch.device
11+
dtype: torch.dtype
12+
13+
14+
def extract_structure(matrix: Matrix) -> Structure:
15+
return Structure(m=matrix.shape[0], device=matrix.device, dtype=matrix.dtype)

src/torchjd/aggregation/_constant.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from torch import Tensor
22

3-
from torchjd._linalg import Matrix
3+
from torchjd.aggregation._weighting_bases import FromNothingWeighting
44

55
from ._aggregator_bases import WeightedAggregator
66
from ._utils.str import vector_to_str
77
from ._weighting_bases import Weighting
88

99

10-
class ConstantWeighting(Weighting[Matrix]):
10+
class _ConstantWeighting(Weighting[None]):
1111
"""
1212
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
1313
weights.
@@ -25,16 +25,13 @@ def __init__(self, weights: Tensor) -> None:
2525
super().__init__()
2626
self.weights = weights
2727

28-
def forward(self, matrix: Tensor, /) -> Tensor:
29-
self._check_matrix_shape(matrix)
28+
def forward(self, _: None, /) -> Tensor:
3029
return self.weights
3130

32-
def _check_matrix_shape(self, matrix: Tensor) -> None:
33-
if matrix.shape[0] != len(self.weights):
34-
raise ValueError(
35-
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
36-
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
37-
)
31+
32+
class ConstantWeighting(FromNothingWeighting):
33+
def __init__(self, weights: Tensor) -> None:
34+
super().__init__(_ConstantWeighting(weights))
3835

3936

4037
class Constant(WeightedAggregator):

src/torchjd/aggregation/_mean.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
11
import torch
22
from torch import Tensor
33

4-
from torchjd._linalg import Matrix
4+
from torchjd._linalg import Structure
5+
from torchjd.aggregation._weighting_bases import FromStructureWeighting
56

67
from ._aggregator_bases import WeightedAggregator
78
from ._weighting_bases import Weighting
89

910

10-
class MeanWeighting(Weighting[Matrix]):
11+
class _MeanWeighting(Weighting[Structure]):
1112
r"""
1213
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
1314
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
1415
\mathbb{R}^m`.
1516
"""
1617

17-
def forward(self, matrix: Tensor, /) -> Tensor:
18-
device = matrix.device
19-
dtype = matrix.dtype
20-
m = matrix.shape[0]
18+
def forward(self, structure: Structure, /) -> Tensor:
19+
device = structure.device
20+
dtype = structure.dtype
21+
m = structure.m
2122
weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
2223
return weights
2324

2425

26+
class MeanWeighting(FromStructureWeighting):
27+
def __init__(self) -> None:
28+
super().__init__(_MeanWeighting())
29+
30+
2531
class Mean(WeightedAggregator):
2632
"""
2733
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input

src/torchjd/aggregation/_random.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,30 @@
22
from torch import Tensor
33
from torch.nn import functional as F
44

5-
from torchjd._linalg import Matrix
5+
from torchjd._linalg import Structure
6+
from torchjd.aggregation._weighting_bases import FromStructureWeighting
67

78
from ._aggregator_bases import WeightedAggregator
89
from ._weighting_bases import Weighting
910

1011

11-
class RandomWeighting(Weighting[Matrix]):
12+
class _RandomWeighting(Weighting[Structure]):
1213
"""
1314
:class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
1415
at each call.
1516
"""
1617

17-
def forward(self, matrix: Tensor, /) -> Tensor:
18-
random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
18+
def forward(self, structure: Structure, /) -> Tensor:
19+
random_vector = torch.randn(structure.m, device=structure.device, dtype=structure.dtype)
1920
weights = F.softmax(random_vector, dim=-1)
2021
return weights
2122

2223

24+
class RandomWeighting(FromStructureWeighting):
25+
def __init__(self) -> None:
26+
super().__init__(_RandomWeighting())
27+
28+
2329
class Random(WeightedAggregator):
2430
"""
2531
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of

src/torchjd/aggregation/_sum.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
import torch
22
from torch import Tensor
33

4-
from torchjd._linalg import Matrix
4+
from torchjd._linalg import Structure
5+
from torchjd.aggregation._weighting_bases import FromStructureWeighting
56

67
from ._aggregator_bases import WeightedAggregator
78
from ._weighting_bases import Weighting
89

910

10-
class SumWeighting(Weighting[Matrix]):
11-
r"""
12-
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
13-
:math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
14-
"""
15-
16-
def forward(self, matrix: Tensor, /) -> Tensor:
17-
device = matrix.device
18-
dtype = matrix.dtype
19-
weights = torch.ones(matrix.shape[0], device=device, dtype=dtype)
11+
class _SumWeighting(Weighting[Structure]):
12+
def forward(self, structure: Structure, /) -> Tensor:
13+
weights = torch.ones(structure.m, device=structure.device, dtype=structure.dtype)
2014
return weights
2115

2216

17+
class SumWeighting(FromStructureWeighting):
18+
def __init__(self) -> None:
19+
super().__init__(_SumWeighting())
20+
21+
2322
class Sum(WeightedAggregator):
2423
"""
2524
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input

src/torchjd/aggregation/_weighting_bases.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from torch import Tensor, nn
88

9-
from torchjd._linalg import PSDTensor, is_psd_tensor
9+
from torchjd._linalg import Matrix, PSDTensor, Structure, extract_structure, is_psd_tensor
1010

11-
_T = TypeVar("_T", contravariant=True, bound=Tensor)
12-
_FnInputT = TypeVar("_FnInputT", bound=Tensor)
13-
_FnOutputT = TypeVar("_FnOutputT", bound=Tensor)
11+
_T = TypeVar("_T", contravariant=True)
12+
_FnInputT = TypeVar("_FnInputT")
13+
_FnOutputT = TypeVar("_FnOutputT")
1414

1515

1616
class Weighting(nn.Module, ABC, Generic[_T]):
@@ -27,11 +27,9 @@ def __init__(self) -> None:
2727
def forward(self, stat: _T, /) -> Tensor:
2828
"""Computes the vector of weights from the input stat."""
2929

30-
def __call__(self, stat: Tensor, /) -> Tensor:
30+
def __call__(self, stat: object, /) -> Tensor:
3131
"""Computes the vector of weights from the input stat and applies all registered hooks."""
3232

33-
# The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of
34-
# stat to be Tensor.
3533
return super().__call__(stat)
3634

3735
def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]:
@@ -55,6 +53,32 @@ def forward(self, stat: _T, /) -> Tensor:
5553
return self.weighting(self.fn(stat))
5654

5755

56+
class FromStructureWeighting(_Composition[Matrix]):
57+
"""
58+
Weighting that extracts the structure of the input matrix before applying a Weighting to it.
59+
60+
:param structure_weighting: The object responsible for extracting the vector of weights from the
61+
structure.
62+
"""
63+
64+
def __init__(self, structure_weighting: Weighting[Structure]) -> None:
65+
super().__init__(structure_weighting, extract_structure)
66+
self.structure_weighting = structure_weighting
67+
68+
69+
class FromNothingWeighting(_Composition[Matrix]):
70+
"""
71+
Weighting that extracts nothing from the input matrix before applying a Weighting to it (i.e. to
72+
None).
73+
74+
:param none_weighting: The object responsible for extracting the vector of weights from nothing.
75+
"""
76+
77+
def __init__(self, none_weighting: Weighting[None]) -> None:
78+
super().__init__(none_weighting, lambda _: None)
79+
self.none_weighting = none_weighting
80+
81+
5882
class GeneralizedWeighting(nn.Module, ABC):
5983
r"""
6084
Abstract base class for all weightings that operate on generalized Gramians. It has the role of

tests/unit/aggregation/test_constant.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,29 +63,6 @@ def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionCon
6363
_ = Constant(weights=weights)
6464

6565

66-
@mark.parametrize(
67-
["weights_shape", "n_rows", "expectation"],
68-
[
69-
([0], 0, does_not_raise()),
70-
([1], 1, does_not_raise()),
71-
([5], 5, does_not_raise()),
72-
([0], 1, raises(ValueError)),
73-
([1], 0, raises(ValueError)),
74-
([4], 5, raises(ValueError)),
75-
([5], 4, raises(ValueError)),
76-
],
77-
)
78-
def test_matrix_shape_check(
79-
weights_shape: list[int], n_rows: int, expectation: ExceptionContext
80-
) -> None:
81-
matrix = ones_([n_rows, 5])
82-
weights = ones_(weights_shape)
83-
aggregator = Constant(weights)
84-
85-
with expectation:
86-
_ = aggregator(matrix)
87-
88-
8966
def test_representations() -> None:
9067
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
9168
assert repr(A) == "Constant(weights=tensor([1., 2.]))"

0 commit comments

Comments
 (0)