Skip to content

Commit 65baea0

Browse files
committed
1 parent 1848568 commit 65baea0

64 files changed

Lines changed: 944 additions & 52 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

latest/_sources/docs/aggregation/index.rst.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ Abstract base classes
1919
.. autoclass:: torchjd.aggregation.Weighting
2020
:members: __call__
2121

22-
.. autoclass:: torchjd.aggregation.Stateful
23-
:members: reset
24-
25-
2622
.. toctree::
2723
:hidden:
2824
:maxdepth: 1

latest/_sources/docs/scalarization/index.rst.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ Abstract base class
1010
.. autoclass:: torchjd.scalarization.Scalarizer
1111
:members: __call__
1212

13-
1413
.. toctree::
1514
:hidden:
1615
:maxdepth: 1
@@ -21,3 +20,4 @@ Abstract base class
2120
random.rst
2221
stch.rst
2322
sum.rst
23+
uw.rst
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
UW
4+
==
5+
6+
.. autoclass:: torchjd.scalarization.UW
7+
:members: __call__, reset
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
:orphan:
2+
3+
.. autoclass:: torchjd.Stateful
4+
:members: reset

latest/_sources/examples/grouping.rst.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ the parameters:
1919

2020
In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
2121
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
22-
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
23-
should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
22+
aggregator instance per group. For :class:`~torchjd.Stateful` aggregators, each instance should
23+
independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
2424
:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).
2525

2626
.. note::

latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@
264264
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
265265
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/stch/">STCH</a></li>
266266
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
267+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/uw/">UW</a></li>
267268
</ul>
268269
</li>
269270
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>

latest/docs/aggregation/cagrad/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@
264264
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
265265
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/stch/">STCH</a></li>
266266
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
267+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/uw/">UW</a></li>
267268
</ul>
268269
</li>
269270
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>

latest/docs/aggregation/config/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@
264264
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
265265
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/stch/">STCH</a></li>
266266
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
267+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/uw/">UW</a></li>
267268
</ul>
268269
</li>
269270
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>

latest/docs/aggregation/constant/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@
264264
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
265265
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/stch/">STCH</a></li>
266266
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
267+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/uw/">UW</a></li>
267268
</ul>
268269
</li>
269270
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>

latest/docs/aggregation/cr_mogm/index.html

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@
264264
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
265265
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/stch/">STCH</a></li>
266266
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
267+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/uw/">UW</a></li>
267268
</ul>
268269
</li>
269270
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
@@ -315,7 +316,7 @@ <h1>CR-MOGM<a class="headerlink" href="#cr-mogm" title="Link to this heading">¶
315316
<dl class="py class">
316317
<dt class="sig sig-object py" id="torchjd.aggregation.CRMOGMWeighting">
317318
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CRMOGMWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weighting</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">alpha</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.9</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">initial_weights</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cr_mogm.py#L14-L155"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CRMOGMWeighting" title="Link to this definition"></a></dt>
318-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Stateful" title="torchjd.aggregation._mixins.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>
319+
<dd><p><a class="reference internal" href="../../stateful/#torchjd.Stateful" title="torchjd.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>
319320
<a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> that wraps another
320321
<a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> and stabilises the weights it
321322
produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
@@ -354,7 +355,7 @@ <h1>CR-MOGM<a class="headerlink" href="#cr-mogm" title="Link to this heading">¶
354355
themselves, so wrapping by <code class="docutils literal notranslate"><span class="pre">CRMOGMWeighting</span></code> will have no effect.</p>
355356
<p>This weighting is stateful: it keeps <span class="math notranslate nohighlight">\(\lambda_{k-1}\)</span> across calls. Use <a class="reference internal" href="#torchjd.aggregation.CRMOGMWeighting.reset" title="torchjd.aggregation.CRMOGMWeighting.reset"><code class="xref py py-meth docutils literal notranslate"><span class="pre">reset()</span></code></a>
356357
to restart the smoothing from the initial state. Note that calling <a class="reference internal" href="#torchjd.aggregation.CRMOGMWeighting.reset" title="torchjd.aggregation.CRMOGMWeighting.reset"><code class="xref py py-meth docutils literal notranslate"><span class="pre">reset()</span></code></a> will also
357-
reset the wrapped weighting if it is <a class="reference internal" href="../#torchjd.aggregation.Stateful" title="torchjd.aggregation.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>.</p>
358+
reset the wrapped weighting if it is <a class="reference internal" href="../../stateful/#torchjd.Stateful" title="torchjd.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>.</p>
358359
<dl class="field-list simple">
359360
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
360361
<dd class="field-odd"><ul class="simple">
@@ -397,7 +398,7 @@ <h1>CR-MOGM<a class="headerlink" href="#cr-mogm" title="Link to this heading">¶
397398
<dt class="sig sig-object py" id="torchjd.aggregation.CRMOGMWeighting.reset">
398399
<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cr_mogm.py#L120-L128"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CRMOGMWeighting.reset" title="Link to this definition"></a></dt>
399400
<dd><p>Clears the EMA state so the next forward restarts from the initial state. Also resets the
400-
wrapped weighting if it is <a class="reference internal" href="../#torchjd.aggregation.Stateful" title="torchjd.aggregation._mixins.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>.</p>
401+
wrapped weighting if it is <a class="reference internal" href="../../stateful/#torchjd.Stateful" title="torchjd.Stateful"><code class="xref py py-class docutils literal notranslate"><span class="pre">Stateful</span></code></a>.</p>
401402
<dl class="field-list simple">
402403
<dt class="field-odd">Return type<span class="colon">:</span></dt>
403404
<dd class="field-odd"><p><span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span></p>

0 commit comments

Comments
 (0)