Commit df75f00 (parent d4797fc): Fix link, a few more improvements / simplifications

1 file changed: docs/source/examples/grouping.rst (9 additions, 9 deletions)
Grouping
========

The aggregation can be performed independently on groups of parameters, at different granularities.
The `Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four strategies to
partition the parameters:

1. **Together** (baseline): one group covering all parameters. Corresponds to the `whole_model`
   strategy in the paper.

2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
   Corresponds to the `enc_dec` strategy in the paper.

3. **Per layer**: one group per leaf module of the network. Corresponds to the `all_layer` strategy
   in the paper.

4. **Per tensor**: one group per individual parameter tensor. Corresponds to the `all_matrix`
   strategy in the paper.
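To make the four granularities concrete, here is a minimal sketch in plain PyTorch of the
corresponding parameter partitions (the encoder/decoder model below is a hypothetical example, not
one from the paper):

```python
import torch.nn as nn

# Hypothetical encoder/decoder model used only to illustrate the partitions.
model = nn.ModuleDict({
    "encoder": nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)),
    "decoder": nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8)),
})

# 1. Together (`whole_model`): a single group covering every parameter.
together = [list(model.parameters())]

# 2. Per network (`enc_dec`): one group per top-level sub-network.
per_network = [list(sub.parameters()) for sub in model.values()]

# 3. Per layer (`all_layer`): one group per leaf module that owns parameters
#    (here, the four Linear layers; the ReLUs are parameter-free).
per_layer = [
    list(m.parameters(recurse=False))
    for m in model.modules()
    if list(m.parameters(recurse=False))
]

# 4. Per tensor (`all_matrix`): one group per individual parameter tensor.
per_tensor = [[p] for p in model.parameters()]
```

Every partition covers the same eight parameter tensors (four `Linear` layers, each with a weight
and a bias); only the group boundaries differ.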

In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
independently maintains its own state (e.g. the EMA :math:`\hat{\phi}` state in
:class:`~torchjd.aggregation.GradVac`), matching the per-block targets from the original paper.

.. note::
   The grouping is orthogonal to the choice between :func:`~torchjd.autojac.backward` and
   :func:`~torchjd.autojac.mtl_backward`. Those functions determine *which* parameters receive
   Jacobians; grouping then determines *how* those Jacobians are partitioned for aggregation.
