Skip to content

Commit b3b7185

Browse files
committed
1 parent 4fdca6f commit b3b7185

60 files changed

Lines changed: 2504 additions & 67 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
Constant
4+
========
5+
6+
.. autoclass:: torchjd.scalarization.Constant
7+
:members: __call__
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
scalarization
2+
=============
3+
4+
.. automodule:: torchjd.scalarization
5+
:no-members:
6+
7+
Abstract base class
8+
-------------------
9+
10+
.. autoclass:: torchjd.scalarization.Scalarizer
11+
:members: __call__
12+
13+
14+
.. toctree::
15+
:hidden:
16+
:maxdepth: 1
17+
18+
constant.rst
19+
mean.rst
20+
random.rst
21+
sum.rst
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
Mean
4+
====
5+
6+
.. autoclass:: torchjd.scalarization.Mean
7+
:members: __call__
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
Random
4+
======
5+
6+
.. autoclass:: torchjd.scalarization.Random
7+
:members: __call__
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
Sum
4+
===
5+
6+
.. autoclass:: torchjd.scalarization.Sum
7+
:members: __call__

latest/_sources/index.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ Jacobian descent is the aggregator, which maps the Jacobian to an optimization s
3131
:doc:`Aggregation <docs/aggregation/index>`, we provide an overview of the various aggregators
3232
available in TorchJD, and their corresponding weightings.
3333

34+
For comparison against simple baselines, the :doc:`Scalarization <docs/scalarization/index>`
35+
package provides scalarizers that combine a tensor of losses into a single scalar loss, allowing
36+
standard gradient descent to be used.
37+
3438
A straightforward application of Jacobian descent is multi-task learning, in which the vector of
3539
per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our
3640
:doc:`MTL example <examples/mtl>`.
@@ -70,4 +74,5 @@ TorchJD is open-source, under MIT License. The source code is available on
7074
docs/autogram/index.rst
7175
docs/autojac/index.rst
7276
docs/aggregation/index.rst
77+
docs/scalarization/index.rst
7378
docs/linalg/index.rst

latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,14 @@
256256
<li class="toctree-l2"><a class="reference internal" href="../trimmed_mean/">Trimmed Mean</a></li>
257257
</ul>
258258
</li>
259-
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
259+
<li class="toctree-l1 has-children"><a class="reference internal" href="../../scalarization/">scalarization</a><input aria-label="Toggle navigation of scalarization" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
260+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/constant/">Constant</a></li>
261+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/mean/">Mean</a></li>
262+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
263+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
264+
</ul>
265+
</li>
266+
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
260267
<li class="toctree-l2"><a class="reference internal" href="../../linalg/matrix/">Matrix</a></li>
261268
<li class="toctree-l2"><a class="reference internal" href="../../linalg/psd_matrix/">PSDMatrix</a></li>
262269
<li class="toctree-l2"><a class="reference internal" href="../../linalg/dual_cone/">Dual Cone Projectors</a></li>
@@ -304,7 +311,7 @@
304311
<h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this heading"></a></h1>
305312
<dl class="py class">
306313
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTL">
307-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L90-L139"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
314+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L91-L140"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
308315
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
309316
<a class="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
310317
<dl class="field-list simple">
@@ -341,7 +348,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
341348

342349
<dl class="py class">
343350
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTLWeighting">
344-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L20-L87"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
351+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L21-L88"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
345352
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> [<a class="reference internal" href="../../linalg/psd_matrix/#torchjd.linalg.PSDMatrix" title="torchjd.linalg.PSDMatrix"><code class="xref py py-class docutils literal notranslate"><span class="pre">PSDMatrix</span></code></a>]
346353
giving the weights of <a class="reference internal" href="#torchjd.aggregation.AlignedMTL" title="torchjd.aggregation.AlignedMTL"><code class="xref py py-class docutils literal notranslate"><span class="pre">AlignedMTL</span></code></a>.</p>
347354
<dl class="field-list simple">

latest/docs/aggregation/cagrad/index.html

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,14 @@
256256
<li class="toctree-l2"><a class="reference internal" href="../trimmed_mean/">Trimmed Mean</a></li>
257257
</ul>
258258
</li>
259-
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
259+
<li class="toctree-l1 has-children"><a class="reference internal" href="../../scalarization/">scalarization</a><input aria-label="Toggle navigation of scalarization" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
260+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/constant/">Constant</a></li>
261+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/mean/">Mean</a></li>
262+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
263+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
264+
</ul>
265+
</li>
266+
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
260267
<li class="toctree-l2"><a class="reference internal" href="../../linalg/matrix/">Matrix</a></li>
261268
<li class="toctree-l2"><a class="reference internal" href="../../linalg/psd_matrix/">PSDMatrix</a></li>
262269
<li class="toctree-l2"><a class="reference internal" href="../../linalg/dual_cone/">Dual Cone Projectors</a></li>

latest/docs/aggregation/config/index.html

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,14 @@
256256
<li class="toctree-l2"><a class="reference internal" href="../trimmed_mean/">Trimmed Mean</a></li>
257257
</ul>
258258
</li>
259-
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
259+
<li class="toctree-l1 has-children"><a class="reference internal" href="../../scalarization/">scalarization</a><input aria-label="Toggle navigation of scalarization" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" role="switch" type="checkbox"/><label for="toctree-checkbox-5"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
260+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/constant/">Constant</a></li>
261+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/mean/">Mean</a></li>
262+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/random/">Random</a></li>
263+
<li class="toctree-l2"><a class="reference internal" href="../../scalarization/sum/">Sum</a></li>
264+
</ul>
265+
</li>
266+
<li class="toctree-l1 has-children"><a class="reference internal" href="../../linalg/">linalg</a><input aria-label="Toggle navigation of linalg" class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" role="switch" type="checkbox"/><label for="toctree-checkbox-6"><span class="icon"><svg><use href="#svg-arrow-right"></use></svg></span></label><ul>
260267
<li class="toctree-l2"><a class="reference internal" href="../../linalg/matrix/">Matrix</a></li>
261268
<li class="toctree-l2"><a class="reference internal" href="../../linalg/psd_matrix/">PSDMatrix</a></li>
262269
<li class="toctree-l2"><a class="reference internal" href="../../linalg/dual_cone/">Dual Cone Projectors</a></li>
@@ -304,7 +311,7 @@
304311
<h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading"></a></h1>
305312
<dl class="py class">
306313
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG">
307-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_config.py#L17-L59"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
314+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_config.py#L18-L60"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
308315
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
309316
Towards Conflict-free Training of Physics Informed Neural Networks</a>.</p>
310317
<dl class="field-list simple">

0 commit comments

Comments
 (0)