Commit b1dc311

1 parent b21c65d commit b1dc311

7 files changed

Lines changed: 43 additions & 13 deletions

latest/_sources/docs/aggregation/index.rst.txt

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ Abstract base classes
 :undoc-members:
 :exclude-members: forward
 
+.. autoclass:: torchjd.aggregation.Stateful
+:members:
+:undoc-members:
+
 
 .. toctree::
 :hidden:

latest/docs/aggregation/gradvac/index.html

Lines changed: 8 additions & 6 deletions
@@ -296,8 +296,9 @@
 <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading"></a></h1>
 <dl class="py class">
 <dt class="sig sig-object py" id="torchjd.aggregation.GradVac">
-<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">GradVac</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1e-08</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L15-L71"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVac" title="Link to this definition"></a></dt>
-<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> implementing the aggregation step of
+<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">GradVac</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1e-08</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L16-L73"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVac" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../#torchjd.aggregation.Stateful" title="torchjd.aggregation._mixins.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>
+<a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> implementing the aggregation step of
 Gradient Vaccine (GradVac) from <a class="reference external" href="https://openreview.net/forum?id=F1vEjWK-lH_">Gradient Vaccine: Investigating and Improving Multi-task
 Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)</a>.</p>
 <p>For each task <span class="math notranslate nohighlight">\(i\)</span>, the order in which other tasks <span class="math notranslate nohighlight">\(j\)</span> are visited is drawn at
@@ -326,7 +327,7 @@ <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading">¶
 </div>
 <dl class="py method">
 <dt class="sig sig-object py" id="torchjd.aggregation.GradVac.reset">
-<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L65-L68"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVac.reset" title="Link to this definition"></a></dt>
+<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L67-L70"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVac.reset" title="Link to this definition"></a></dt>
 <dd><p>Clears EMA state so the next forward starts from zero targets.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Return type<span class="colon">:</span></dt>
@@ -339,8 +340,9 @@ <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading">¶
 
 <dl class="py class">
 <dt class="sig sig-object py" id="torchjd.aggregation.GradVacWeighting">
-<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">GradVacWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1e-08</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L74-L190"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVacWeighting" title="Link to this definition"></a></dt>
-<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
+<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">GradVacWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1e-08</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L76-L193"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVacWeighting" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../#torchjd.aggregation.Stateful" title="torchjd.aggregation._mixins.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>
+<a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
 <a class="reference internal" href="#torchjd.aggregation.GradVac" title="torchjd.aggregation.GradVac"><code class="xref py py-class docutils literal notranslate"><span class="pre">GradVac</span></code></a>.</p>
 <p>All required quantities (gradient norms, cosine similarities, and their updates after the
 vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
@@ -365,7 +367,7 @@ <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading">¶
 </dl>
 <dl class="py method">
 <dt class="sig sig-object py" id="torchjd.aggregation.GradVacWeighting.reset">
-<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L131-L135"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVacWeighting.reset" title="Link to this definition"></a></dt>
+<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_gradvac.py#L134-L138"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVacWeighting.reset" title="Link to this definition"></a></dt>
 <dd><p>Clears EMA state so the next forward starts from zero targets.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Return type<span class="colon">:</span></dt>

latest/docs/aggregation/index.html

Lines changed: 21 additions & 0 deletions
@@ -377,6 +377,23 @@ <h2>Abstract base classes<a class="headerlink" href="#abstract-base-classes" tit
 <span class="math notranslate nohighlight">\(m_1 \times \dots \times m_k \times m_k \times \dots \times m_1\)</span>.</p>
 </dd></dl>
 
+<dl class="py class">
+<dt class="sig sig-object py" id="torchjd.aggregation.Stateful">
+<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">Stateful</span></span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L4-L9"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.Stateful" title="Link to this definition"></a></dt>
+<dd><p>Mixin adding a reset method.</p>
+<dl class="py method">
+<dt class="sig sig-object py" id="torchjd.aggregation.Stateful.reset">
+<span class="property"><span class="k"><span class="pre">abstractmethod</span></span><span class="w"> </span></span><span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L7-L9"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.Stateful.reset" title="Link to this definition"></a></dt>
+<dd><p>Resets the internal state.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span></p>
+</dd>
+</dl>
+</dd></dl>
+
+</dd></dl>
+
 <div class="toctree-wrapper compound">
 </div>
 </section>
@@ -442,6 +459,10 @@ <h2>Abstract base classes<a class="headerlink" href="#abstract-base-classes" tit
 <li><a class="reference internal" href="#torchjd.aggregation.Aggregator"><code class="docutils literal notranslate"><span class="pre">Aggregator</span></code></a></li>
 <li><a class="reference internal" href="#torchjd.aggregation.Weighting"><code class="docutils literal notranslate"><span class="pre">Weighting</span></code></a></li>
 <li><a class="reference internal" href="#torchjd.aggregation.GeneralizedWeighting"><code class="docutils literal notranslate"><span class="pre">GeneralizedWeighting</span></code></a></li>
+<li><a class="reference internal" href="#torchjd.aggregation.Stateful"><code class="docutils literal notranslate"><span class="pre">Stateful</span></code></a><ul>
+<li><a class="reference internal" href="#torchjd.aggregation.Stateful.reset"><code class="docutils literal notranslate"><span class="pre">Stateful.reset()</span></code></a></li>
+</ul>
+</li>
 </ul>
 </li>
 </ul>
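
The `Stateful` entry added above is documented only as a "Mixin adding a reset method" with an abstract `reset()` returning `None`. A minimal sketch of that pattern follows, assuming a plain `abc`-based mixin; the actual contents of `_mixins.py` are not shown in this diff, and `RunningMean` is a hypothetical class used purely for illustration, not part of torchjd.

```python
from abc import ABC, abstractmethod


class Stateful(ABC):
    """Sketch of a mixin adding a reset method (assumed shape, not the library source)."""

    @abstractmethod
    def reset(self) -> None:
        """Resets the internal state."""


class RunningMean(Stateful):
    """Hypothetical stateful object that keeps an exponential moving average."""

    def __init__(self, beta: float = 0.5) -> None:
        self.beta = beta
        self.value = 0.0

    def update(self, x: float) -> float:
        # Blend the previous average with the new observation.
        self.value = self.beta * self.value + (1.0 - self.beta) * x
        return self.value

    def reset(self) -> None:
        # Clear the average so the next update starts from zero again.
        self.value = 0.0
```

Because `reset()` is abstract, a subclass that forgets to implement it cannot be instantiated, which is one plausible reason to express the contract as a mixin rather than as a naming convention.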

latest/docs/aggregation/nash_mtl/index.html

Lines changed: 4 additions & 3 deletions
@@ -296,8 +296,9 @@
 <h1>Nash-MTL<a class="headerlink" href="#nash-mtl" title="Link to this heading"></a></h1>
 <dl class="py class">
 <dt class="sig sig-object py" id="torchjd.aggregation.NashMTL">
-<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">NashMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">n_tasks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_norm</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_weights_every</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optim_niter</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">20</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L23-L83"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.NashMTL" title="Link to this definition"></a></dt>
-<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as proposed in Algorithm 1 of
+<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">NashMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">n_tasks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_norm</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_weights_every</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optim_niter</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">20</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L24-L85"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.NashMTL" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../#torchjd.aggregation.Stateful" title="torchjd.aggregation._mixins.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>
+<a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as proposed in Algorithm 1 of
 <a class="reference external" href="https://arxiv.org/pdf/2202.01017.pdf">Multi-Task Learning as a Bargaining Game</a>.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -331,7 +332,7 @@ <h1>Nash-MTL<a class="headerlink" href="#nash-mtl" title="Link to this heading">
 </div>
 <dl class="py method">
 <dt class="sig sig-object py" id="torchjd.aggregation.NashMTL.reset">
-<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L75-L77"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.NashMTL.reset" title="Link to this definition"></a></dt>
+<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_nash_mtl.py#L77-L79"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.NashMTL.reset" title="Link to this definition"></a></dt>
 <dd><p>Resets the internal state of the algorithm.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Return type<span class="colon">:</span></dt>

latest/genindex/index.html

Lines changed: 5 additions & 3 deletions
@@ -494,14 +494,14 @@ <h2>R</h2>
 </li>
 <li><a href="../docs/aggregation/random/#torchjd.aggregation.RandomWeighting">RandomWeighting (class in torchjd.aggregation)</a>
 </li>
-</ul></td>
-<td style="width: 33%; vertical-align: top;"><ul>
 <li><a href="../docs/aggregation/gradvac/#torchjd.aggregation.GradVac.reset">reset() (torchjd.aggregation.GradVac method)</a>
 
 <ul>
 <li><a href="../docs/aggregation/gradvac/#torchjd.aggregation.GradVacWeighting.reset">(torchjd.aggregation.GradVacWeighting method)</a>
 </li>
 <li><a href="../docs/aggregation/nash_mtl/#torchjd.aggregation.NashMTL.reset">(torchjd.aggregation.NashMTL method)</a>
+</li>
+<li><a href="../docs/aggregation/#torchjd.aggregation.Stateful.reset">(torchjd.aggregation.Stateful method)</a>
 </li>
 </ul></li>
 </ul></td>
@@ -512,10 +512,12 @@ <h2>R</h2>
 <h2>S</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
 <td style="width: 33%; vertical-align: top;"><ul>
-<li><a href="../docs/aggregation/sum/#torchjd.aggregation.Sum">Sum (class in torchjd.aggregation)</a>
+<li><a href="../docs/aggregation/#torchjd.aggregation.Stateful">Stateful (class in torchjd.aggregation)</a>
 </li>
 </ul></td>
 <td style="width: 33%; vertical-align: top;"><ul>
+<li><a href="../docs/aggregation/sum/#torchjd.aggregation.Sum">Sum (class in torchjd.aggregation)</a>
+</li>
 <li><a href="../docs/aggregation/sum/#torchjd.aggregation.SumWeighting">SumWeighting (class in torchjd.aggregation)</a>
 </li>
 </ul></td>

latest/objects.inv

30 Bytes
Binary file not shown.

latest/searchindex.js

Lines changed: 1 addition & 1 deletion
Generated file not rendered by default.
