You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> implementing the aggregation step of
301
302
Gradient Vaccine (GradVac) from <aclass="reference external" href="https://openreview.net/forum?id=F1vEjWK-lH_">Gradient Vaccine: Investigating and Improving Multi-task
302
303
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)</a>.</p>
303
304
<p>For each task <spanclass="math notranslate nohighlight">\(i\)</span>, the order in which other tasks <spanclass="math notranslate nohighlight">\(j\)</span> are visited is drawn at
@@ -326,7 +327,7 @@ <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading">¶
<spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L65-L68"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVac.reset" title="Link to this definition">¶</a></dt>
330
+
<spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L67-L70"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVac.reset" title="Link to this definition">¶</a></dt>
330
331
<dd><p>Clears EMA state so the next forward starts from zero targets.</p>
<spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L131-L135"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVacWeighting.reset" title="Link to this definition">¶</a></dt>
370
+
<spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L134-L138"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVacWeighting.reset" title="Link to this definition">¶</a></dt>
369
371
<dd><p>Clears EMA state so the next forward starts from zero targets.</p>
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">Stateful</span></span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L4-L9"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.Stateful" title="Link to this definition">¶</a></dt>
<spanclass="property"><spanclass="k"><spanclass="pre">abstractmethod</span></span><spanclass="w"></span></span><spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L7-L9"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.Stateful.reset" title="Link to this definition">¶</a></dt>
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">NashMTL</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">n_tasks</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">max_norm</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">1.0</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">update_weights_every</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">1</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">optim_niter</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">20</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L23-L83"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.NashMTL" title="Link to this definition">¶</a></dt>
300
-
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> as proposed in Algorithm 1 of
299
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">NashMTL</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">n_tasks</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">max_norm</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">1.0</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">update_weights_every</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">1</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">optim_niter</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">20</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L24-L85"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.NashMTL" title="Link to this definition">¶</a></dt>
<aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> as proposed in Algorithm 1 of
301
302
<aclass="reference external" href="https://arxiv.org/pdf/2202.01017.pdf">Multi-Task Learning as a Bargaining Game</a>.</p>
<spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L75-L77"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.NashMTL.reset" title="Link to this definition">¶</a></dt>
335
+
<spanclass="sig-name descname"><spanclass="pre">reset</span></span><spanclass="sig-paren">(</span><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L77-L79"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.NashMTL.reset" title="Link to this definition">¶</a></dt>
335
336
<dd><p>Resets the internal state of the algorithm.</p>
0 commit comments