@@ -31,24 +31,13 @@ class GradVac(GramianWeightedAggregator):
     This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
     the number of tasks or dtype changes.
 
-    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign
-        the :attr:`beta` attribute between steps to tune the EMA update.
-    :param eps: Small non-negative constant added to denominators when computing cosines and the
-        vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or
-        assign the :attr:`eps` attribute between steps to tune numerical behavior.
+    :param beta: EMA decay for :math:`\hat{\phi}`.
+    :param eps: Small non-negative constant added to denominators.
 
     .. note::
         For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
         using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
         you need reproducibility.
-
-    .. note::
-        To apply GradVac with per-layer or per-parameter-group granularity, create a separate
-        :class:`GradVac` instance for each group and call
-        :func:`~torchjd.autojac.jac_to_grad` once per group after
-        :func:`~torchjd.autojac.mtl_backward`. Each instance maintains its own EMA state,
-        matching the per-block targets :math:`\hat{\phi}_{ijk}` from the original paper. See
-        the :doc:`Grouping </examples/grouping>` example for details.
     """
 
     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
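The removed ``:param`` text above documented that :attr:`beta` and :attr:`eps` can be read or assigned between steps, and the surviving note covers RNG seeding. A minimal usage sketch of both points, assuming ``GradVac`` is importable from ``torchjd.aggregation`` and, like torchjd's other aggregators, is called directly on a matrix whose rows are per-task gradients (the import path and call convention are assumptions, not shown in this diff):

```python
import torch

from torchjd.aggregation import GradVac  # assumed import path

torch.manual_seed(0)  # GradVac shuffles task order via the global RNG

aggregator = GradVac(beta=0.5, eps=1e-8)

jacobian = torch.randn(3, 10)  # 3 tasks, 10 shared parameters
update = aggregator(jacobian)  # first call initializes the EMA state

aggregator.beta = 0.9  # tune the EMA update between steps
update = aggregator(torch.randn(3, 10))

# The EMA targets persist across calls, so reset the state before
# aggregating a Jacobian with a different number of tasks or dtype.
aggregator.reset()
update = aggregator(torch.randn(4, 10))  # now 4 tasks
```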
@@ -59,8 +48,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
 
     @property
     def beta(self) -> float:
-        """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""
-
         return self._gradvac_weighting.beta
 
     @beta.setter
@@ -69,8 +56,6 @@ def beta(self, value: float) -> None:
 
     @property
     def eps(self) -> float:
-        """Small non-negative constant added to denominators for numerical stability."""
-
         return self._gradvac_weighting.eps
 
     @eps.setter
@@ -107,8 +92,8 @@ class GradVacWeighting(Weighting[PSDMatrix]):
     This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
     the number of tasks or dtype changes.
 
-    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``).
-    :param eps: Small non-negative constant added to denominators (default ``1e-8``).
+    :param beta: EMA decay for :math:`\hat{\phi}`.
+    :param eps: Small non-negative constant added to denominators.
     """
 
     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
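For context on these two parameters, a sketch following the GradVac paper (Wang et al., ICLR 2021); the placement of :math:`\varepsilon` reflects the stabilization described in the removed docstring text, not code shown in this diff. The EMA target for each task pair is updated as :math:`\hat{\phi}_{ij} \leftarrow (1 - \beta)\,\hat{\phi}_{ij} + \beta\,\phi_{ij}`, where :math:`\phi_{ij} = \langle g_i, g_j \rangle / (\|g_i\|\,\|g_j\| + \varepsilon)` is the observed cosine similarity, and whenever :math:`\phi_{ij} < \hat{\phi}_{ij}`, gradient :math:`g_i` receives the correction :math:`g_i + \frac{\|g_i\|\,(\hat{\phi}_{ij}\sqrt{1 - \phi_{ij}^2} - \phi_{ij}\sqrt{1 - \hat{\phi}_{ij}^2})}{\|g_j\|\,\sqrt{1 - \hat{\phi}_{ij}^2} + \varepsilon}\, g_j`.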
@@ -125,8 +110,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
 
     @property
     def beta(self) -> float:
-        """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""
-
         return self._beta
 
     @beta.setter
@@ -137,8 +120,6 @@ def beta(self, value: float) -> None:
137120
138121 @property
139122 def eps (self ) -> float :
140- """Small non-negative constant added to denominators for numerical stability."""
141-
142123 return self ._eps
143124
144125 @eps .setter