Skip to content

Commit ba21fa0

Browse files
committed
Fix documentation
1 parent c139fc1 commit ba21fa0

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

src/torchjd/aggregation/_constant.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,6 @@
88

99

1010
class _ConstantWeighting(Weighting[None]):
11-
"""
12-
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
13-
weights.
14-
15-
:param weights: The weights to return at each call.
16-
"""
17-
1811
def __init__(self, weights: Tensor) -> None:
1912
if weights.dim() != 1:
2013
raise ValueError(
@@ -30,6 +23,13 @@ def forward(self, _: None, /) -> Tensor:
3023

3124

3225
class ConstantWeighting(FromNothingWeighting):
26+
"""
27+
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
28+
weights.
29+
30+
:param weights: The weights to return at each call.
31+
"""
32+
3333
def __init__(self, weights: Tensor) -> None:
3434
super().__init__(_ConstantWeighting(weights))
3535

src/torchjd/aggregation/_mean.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@
99

1010

1111
class _MeanWeighting(Weighting[Structure]):
12-
r"""
13-
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
14-
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
15-
\mathbb{R}^m`.
16-
"""
17-
1812
def forward(self, structure: Structure, /) -> Tensor:
1913
device = structure.device
2014
dtype = structure.dtype
@@ -24,6 +18,12 @@ def forward(self, structure: Structure, /) -> Tensor:
2418

2519

2620
class MeanWeighting(FromStructureWeighting):
21+
r"""
22+
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
23+
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
24+
\mathbb{R}^m`.
25+
"""
26+
2727
def __init__(self) -> None:
2828
super().__init__(_MeanWeighting())
2929

src/torchjd/aggregation/_random.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@
1010

1111

1212
class _RandomWeighting(Weighting[Structure]):
13-
"""
14-
:class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
15-
at each call.
16-
"""
17-
1813
def forward(self, structure: Structure, /) -> Tensor:
1914
random_vector = torch.randn(structure.m, device=structure.device, dtype=structure.dtype)
2015
weights = F.softmax(random_vector, dim=-1)
2116
return weights
2217

2318

2419
class RandomWeighting(FromStructureWeighting):
20+
"""
21+
:class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
22+
at each call.
23+
"""
24+
2525
def __init__(self) -> None:
2626
super().__init__(_RandomWeighting())
2727

src/torchjd/aggregation/_sum.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def forward(self, structure: Structure, /) -> Tensor:
1515

1616

1717
class SumWeighting(FromStructureWeighting):
18+
r"""
19+
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
20+
:math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
21+
"""
22+
1823
def __init__(self) -> None:
1924
super().__init__(_SumWeighting())
2025

0 commit comments

Comments
 (0)