Skip to content

Commit dc5e691

Browse files
committed
1 parent 81d2d92 commit dc5e691

67 files changed

Lines changed: 1642 additions & 214 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

stable/_sources/docs/aggregation/index.rst.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ Abstract base classes
1919
.. autoclass:: torchjd.aggregation.Weighting
2020
:members: __call__
2121

22-
.. autoclass:: torchjd.aggregation.Stateful
23-
:members: reset
24-
25-
2622
.. toctree::
2723
:hidden:
2824
:maxdepth: 1
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
IMTL-L
4+
======
5+
6+
.. autoclass:: torchjd.scalarization.IMTLL
7+
:members: __call__, reset

stable/_sources/docs/scalarization/index.rst.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ Abstract base class
1010
.. autoclass:: torchjd.scalarization.Scalarizer
1111
:members: __call__
1212

13-
1413
.. toctree::
1514
:hidden:
1615
:maxdepth: 1
1716

1817
constant.rst
1918
geometric_mean.rst
19+
imtl_l.rst
2020
mean.rst
2121
random.rst
2222
stch.rst
2323
sum.rst
24+
uw.rst
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
UW
4+
==
5+
6+
.. autoclass:: torchjd.scalarization.UW
7+
:members: __call__, reset
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
:orphan:
2+
3+
.. autoclass:: torchjd.Stateful
4+
:members: reset

stable/_sources/examples/grouping.rst.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ the parameters:
1919

2020
In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
2121
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
22-
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
23-
should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
22+
aggregator instance per group. For :class:`~torchjd.Stateful` aggregators, each instance should
23+
independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
2424
:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).
2525

2626
.. note::

stable/_static/documentation_options.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
const DOCUMENTATION_OPTIONS = {
2-
VERSION: '0.13.0',
2+
VERSION: '0.14.0',
33
LANGUAGE: 'en',
44
COLLAPSE_INDEX: false,
55
BUILDER: 'dirhtml',

stable/docs/aggregation/aligned_mtl/index.html

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,12 @@
260260
<li class="toctree-l1 has-children"><a class="reference internal" href="../../scalarization/">scalarization</a><input aria-label="Toggle navigation of scalarization" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
261261
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/constant/">Constant</a></li>
262262
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/geometric_mean/">GeometricMean</a></li>
263+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/imtl_l/">IMTL-L</a></li>
263264
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/mean/">Mean</a></li>
264265
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
265266
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/stch/">STCH</a></li>
266267
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
268+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/uw/">UW</a></li>
267269
</ul>
268270
</li>
269271
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
@@ -314,7 +316,7 @@
314316
<h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this heading"></a></h1>
315317
<dl class="py class">
316318
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTL">
317-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.13.0/src/torchjd/aggregation/_aligned_mtl.py#L91-L140"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
319+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.14.0/src/torchjd/aggregation/_aligned_mtl.py#L91-L140"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
318320
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
319321
<a class="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
320322
<dl class="field-list simple">
@@ -335,7 +337,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
335337
</div>
336338
<dl class="py method">
337339
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTL.__call__">
338-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.13.0/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL.__call__" title="Link to this definition"></a></dt>
340+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.14.0/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL.__call__" title="Link to this definition"></a></dt>
339341
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
340342
<dl class="field-list simple">
341343
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -351,7 +353,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
351353

352354
<dl class="py class">
353355
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTLWeighting">
354-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.13.0/src/torchjd/aggregation/_aligned_mtl.py#L21-L88"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
356+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.14.0/src/torchjd/aggregation/_aligned_mtl.py#L21-L88"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
355357
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> [<a class="reference internal" href="../../linalg/psd_matrix/#torchjd.linalg.PSDMatrix" title="torchjd.linalg.PSDMatrix"><code class="xref py py-class docutils literal notranslate"><span class="pre">PSDMatrix</span></code></a>]
356358
giving the weights of <a class="reference internal" href="#torchjd.aggregation.AlignedMTL" title="torchjd.aggregation.AlignedMTL"><code class="xref py py-class docutils literal notranslate"><span class="pre">AlignedMTL</span></code></a>.</p>
357359
<dl class="field-list simple">
@@ -367,7 +369,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
367369
</dl>
368370
<dl class="py method">
369371
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTLWeighting.__call__">
370-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">gramian</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.13.0/src/torchjd/aggregation/_weighting_bases.py#L71-L77"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting.__call__" title="Link to this definition"></a></dt>
372+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">gramian</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/v0.14.0/src/torchjd/aggregation/_weighting_bases.py#L71-L77"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting.__call__" title="Link to this definition"></a></dt>
371373
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
372374
<dl class="field-list simple">
373375
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -432,7 +434,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
432434

433435
</aside>
434436
</div>
435-
</div><script src="../../../_static/documentation_options.js?v=2dc2599f"></script>
437+
</div><script src="../../../_static/documentation_options.js?v=47edbadb"></script>
436438
<script src="../../../_static/doctools.js?v=fd6eb6e6"></script>
437439
<script src="../../../_static/sphinx_highlight.js?v=6ffebe34"></script>
438440
<script src="../../../_static/scripts/furo.js?v=46bd48cc"></script>

0 commit comments

Comments
 (0)