@@ -2,29 +2,29 @@ Grouping
 ========
 
 The aggregation can be made independently on groups of parameters, at different granularities. The
-[Gradient Vaccine paper](https://arxiv.org/pdf/2010.05874) introduces four strategies to partition
+`Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four strategies to partition
 the parameters:
 
-1. **Together** (baseline): one group covering all shared parameters. Corresponds to the
-   `whole_model` strategy in the paper.
+1. **Together** (baseline): one group covering all parameters. Corresponds to the `whole_model`
+   strategy in the paper.
 
 2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
    Corresponds to the `enc_dec` strategy in the paper.
 
-3. **Per layer**: one group per leaf module of the encoder. Corresponds to the `all_layer` strategy
+3. **Per layer**: one group per leaf module of the network. Corresponds to the `all_layer` strategy
    in the paper.
 
 4. **Per tensor**: one group per individual parameter tensor. Corresponds to the `all_matrix`
    strategy in the paper.
 
 In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
-after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
-For :class:`~torchjd.aggregation.Stateful` aggregators, each instance independently maintains its
-own state (e.g. the EMA :math:`\hat{\phi}` state in :class:`~torchjd.aggregation.GradVac`), matching
-the per-block targets from the original paper.
+after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
+aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
+should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
+:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).
 
 .. note::
-    The grouping is orthogonal to the choice of
+    The grouping is orthogonal to the choice between
     :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions
     determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians
     are partitioned for aggregation.
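
To make the four partitioning strategies in the page above concrete, here is a minimal sketch in plain Python. It is not TorchJD code: `group_parameter_names` and `ToyEmaState` are hypothetical helpers introduced only for illustration. The sketch assumes parameter names follow the dotted convention of PyTorch's `named_parameters()`, that the first component names the top-level sub-network, and that everything before the last component names the leaf module owning the tensor. The toy state class only stands in for the idea of one stateful aggregator instance per group; it does not reproduce GradVac.

```python
def group_parameter_names(names, strategy):
    """Partition dotted parameter names into groups (hypothetical helper).

    strategy is one of "together", "per_network", "per_layer", "per_tensor".
    Returns a dict mapping a group key to the list of names in that group.
    """
    groups = {}
    for name in names:
        parts = name.split(".")
        if strategy == "together":
            key = "all"                  # a single group for everything
        elif strategy == "per_network":
            key = parts[0]               # top-level sub-network, e.g. "encoder"
        elif strategy == "per_layer":
            key = ".".join(parts[:-1])   # module that owns the tensor
        elif strategy == "per_tensor":
            key = name                   # one group per parameter tensor
        else:
            raise ValueError(f"unknown strategy: {strategy!r}")
        groups.setdefault(key, []).append(name)
    return groups


class ToyEmaState:
    """Toy stand-in for per-group state: an exponential moving average."""

    def __init__(self, beta=0.5):
        self.beta = beta
        self.value = 0.0

    def update(self, observation):
        self.value = (1.0 - self.beta) * self.value + self.beta * observation
        return self.value


# Illustrative names only; a real model would supply its own.
names = [
    "encoder.fc1.weight",
    "encoder.fc1.bias",
    "encoder.fc2.weight",
    "decoder.out.weight",
    "decoder.out.bias",
]

groups = group_parameter_names(names, "per_network")

# One dedicated state object per group, so updating one group
# never touches the state of another.
states = {key: ToyEmaState() for key in groups}
states["encoder"].update(1.0)
```

In an actual training loop, each group's parameters would presumably be passed to its own `jac_to_grad` call with that group's aggregator, as the page describes; the exact call signature is not shown in this diff, so it is left out here.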