Skip to content

Commit d62a474

Browse files
committed
Simplify parameter explanation
- The goal here is to uniformize a bit with the rest of the library: it's not really needed to indicate default values, because they already appear in the built documentation (unless the default is some None that later gets transformed into something else, but it's not the case here). Also, I don't think it's needed to indicate that these parameters can be changed afterwards, because I don't think a lot of people will do that, and it's actually the case of the parameters of all aggregators (or it should be). Lastly, I made the description be the same between aggregator and weighting (for ease of maintainance).
1 parent f028d54 commit d62a474

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

src/torchjd/aggregation/_gradvac.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,8 @@ class GradVac(GramianWeightedAggregator):
3131
This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
3232
the number of tasks or dtype changes.
3333
34-
:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign
35-
the :attr:`beta` attribute between steps to tune the EMA update.
36-
:param eps: Small non-negative constant added to denominators when computing cosines and the
37-
vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or
38-
assign the :attr:`eps` attribute between steps to tune numerical behavior.
34+
:param beta: EMA decay for :math:`\hat{\phi}`.
35+
:param eps: Small non-negative constant added to denominators.
3936
4037
.. note::
4138
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
@@ -95,8 +92,8 @@ class GradVacWeighting(Weighting[PSDMatrix]):
9592
This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
9693
the number of tasks or dtype changes.
9794
98-
:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``).
99-
:param eps: Small non-negative constant added to denominators (default ``1e-8``).
95+
:param beta: EMA decay for :math:`\hat{\phi}`.
96+
:param eps: Small non-negative constant added to denominators.
10097
"""
10198

10299
def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:

0 commit comments

Comments
 (0)