Commit 787f486

Merge branch 'main' into feature/interactive-plotting-ui

2 parents 6a78932 + 012b1ba

2 files changed: 6 additions & 25 deletions

File tree:

    docs/source/docs/aggregation/gradvac.rst
    src/torchjd/aggregation/_gradvac.py

docs/source/docs/aggregation/gradvac.rst (2 additions, 2 deletions)
@@ -6,9 +6,9 @@ GradVac
 .. autoclass:: torchjd.aggregation.GradVac
     :members:
     :undoc-members:
-    :exclude-members: forward
+    :exclude-members: forward, eps, beta

 .. autoclass:: torchjd.aggregation.GradVacWeighting
     :members:
     :undoc-members:
-    :exclude-members: forward
+    :exclude-members: forward, eps, beta

src/torchjd/aggregation/_gradvac.py (4 additions, 23 deletions)
@@ -31,24 +31,13 @@ class GradVac(GramianWeightedAggregator):
     This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
     the number of tasks or dtype changes.

-    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign
-        the :attr:`beta` attribute between steps to tune the EMA update.
-    :param eps: Small non-negative constant added to denominators when computing cosines and the
-        vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or
-        assign the :attr:`eps` attribute between steps to tune numerical behavior.
+    :param beta: EMA decay for :math:`\hat{\phi}`.
+    :param eps: Small non-negative constant added to denominators.

     .. note::
         For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
         using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
         you need reproducibility.
-
-    .. note::
-        To apply GradVac with per-layer or per-parameter-group granularity, create a separate
-        :class:`GradVac` instance for each group and call
-        :func:`~torchjd.autojac.jac_to_grad` once per group after
-        :func:`~torchjd.autojac.mtl_backward`. Each instance maintains its own EMA state,
-        matching the per-block targets :math:`\hat{\phi}_{ijk}` from the original paper. See
-        the :doc:`Grouping </examples/grouping>` example for details.
     """

     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
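The docstring above documents stateful behavior that is easiest to see end to end. Below is a minimal usage sketch, assuming torchjd's usual aggregator convention that an instance maps a (tasks x params) Jacobian matrix to a single gradient vector; the shapes and the call on the instance are assumptions, while ``reset``, the writable ``beta``/``eps`` attributes, and the ``torch.manual_seed`` advice come from the docstring itself.

    import torch
    from torchjd.aggregation import GradVac

    # The docstring notes that task order is shuffled via the global PyTorch RNG,
    # so seed it if you need reproducibility.
    torch.manual_seed(0)

    aggregator = GradVac(beta=0.5, eps=1e-8)

    # Assumed convention: one row of gradients per task (here 3 tasks, 5 params).
    jacobian = torch.randn(3, 5)
    aggregated = aggregator(jacobian)  # assumed to return a single gradient vector

    # The EMA state of phi-hat persists across calls; the exposed attributes
    # can be read or assigned between steps.
    aggregator.beta = 0.8

    # Reset the stored EMA when the number of tasks (or dtype) changes.
    aggregator.reset()
    aggregated = aggregator(torch.randn(4, 5))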
@@ -59,8 +48,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:

     @property
     def beta(self) -> float:
-        """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""
-
         return self._gradvac_weighting.beta

     @beta.setter
@@ -69,8 +56,6 @@ def beta(self, value: float) -> None:

     @property
     def eps(self) -> float:
-        """Small non-negative constant added to denominators for numerical stability."""
-
         return self._gradvac_weighting.eps

     @eps.setter
@@ -107,8 +92,8 @@ class GradVacWeighting(Weighting[PSDMatrix]):
     This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
     the number of tasks or dtype changes.

-    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``).
-    :param eps: Small non-negative constant added to denominators (default ``1e-8``).
+    :param beta: EMA decay for :math:`\hat{\phi}`.
+    :param eps: Small non-negative constant added to denominators.
     """

     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
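To make the role of ``beta`` and ``eps`` concrete, here is a hypothetical sketch of a single GradVac pair update in the spirit of the original paper (Wang et al., 2021); it is not torchjd's implementation, and both the exact update formula and the EMA convention (reading ``beta`` as the weight on the old value, per "decay") are assumptions.

    import math
    import torch

    def gradvac_pair_update(g_i: torch.Tensor, g_j: torch.Tensor,
                            phi_hat: float, beta: float = 0.5,
                            eps: float = 1e-8) -> tuple[torch.Tensor, float]:
        # eps stabilizes the cosine denominator, as the docstring describes.
        cos = float(g_i @ g_j / (g_i.norm() * g_j.norm() + eps))
        if cos < phi_hat:
            # Alignment fell below the EMA target: nudge g_i toward g_j so their
            # cosine matches phi_hat. eps also stabilizes this "vaccine weight"
            # denominator.
            num = g_i.norm() * (phi_hat * math.sqrt(1 - cos ** 2)
                                - cos * math.sqrt(1 - phi_hat ** 2))
            den = g_j.norm() * math.sqrt(1 - phi_hat ** 2) + eps
            g_i = g_i + (num / den) * g_j
        # EMA update of the alignment target, with beta as the decay on the old value.
        phi_hat = beta * phi_hat + (1 - beta) * cos
        return g_i, phi_hat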
@@ -125,8 +110,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:

     @property
     def beta(self) -> float:
-        """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""
-
         return self._beta

     @beta.setter
@@ -137,8 +120,6 @@ def beta(self, value: float) -> None:

     @property
     def eps(self) -> float:
-        """Small non-negative constant added to denominators for numerical stability."""
-
         return self._eps

     @eps.setter
