Skip to content

Commit 769ba12

Browse files
committed
1 parent cdf2e9a commit 769ba12

22 files changed

Lines changed: 62 additions & 956 deletions

File tree

latest/_sources/installation.md.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,17 @@
66
```
77

88
Note that `torchjd` requires python 3.10, 3.11, 3.12 or 3.13 and `torch>=2.0`.
9+
10+
Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
11+
when installing `torchjd`. To install them, you can use:
12+
```
13+
pip install torchjd[cagrad]
14+
```
15+
```
16+
pip install torchjd[nash_mtl]
17+
```
18+
19+
To install `torchjd` with all of its optional dependencies, you can also use:
20+
```
21+
pip install torchjd[full]
22+
```

latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -237,38 +237,8 @@
237237
</label>
238238
</div>
239239
<article role="main">
240-
<section id="module-torchjd.aggregation.aligned_mtl">
241-
<span id="aligned-mtl"></span><h1>Aligned-MTL<a class="headerlink" href="#module-torchjd.aggregation.aligned_mtl" title="Link to this heading"></a></h1>
242-
<dl class="py class">
243-
<dt class="sig sig-object py" id="torchjd.aggregation.aligned_mtl.AlignedMTL">
244-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.aligned_mtl.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/aligned_mtl.py#L37-L74"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.aligned_mtl.AlignedMTL" title="Link to this definition"></a></dt>
245-
<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
246-
<a class="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
247-
<dl class="field-list simple">
248-
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
249-
<dd class="field-odd"><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector to use.</p>
250-
</dd>
251-
</dl>
252-
<div class="admonition-example admonition">
253-
<p class="admonition-title">Example</p>
254-
<p>Use AlignedMTL to aggregate a matrix.</p>
255-
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
256-
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">AlignedMTL</span>
257-
<span class="gp">&gt;&gt;&gt;</span>
258-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">AlignedMTL</span><span class="p">()</span>
259-
<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
260-
<span class="gp">&gt;&gt;&gt;</span>
261-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
262-
<span class="go">tensor([0.2133, 0.9673, 0.9673])</span>
263-
</pre></div>
264-
</div>
265-
</div>
266-
<div class="admonition note">
267-
<p class="admonition-title">Note</p>
268-
<p>This implementation was adapted from the <a class="reference external" href="https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned">official implementation</a>.</p>
269-
</div>
270-
</dd></dl>
271-
240+
<section id="aligned-mtl">
241+
<h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this heading"></a></h1>
272242
</section>
273243

274244
</article>

latest/docs/aggregation/cagrad/index.html

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -237,37 +237,8 @@
237237
</label>
238238
</div>
239239
<article role="main">
240-
<section id="module-torchjd.aggregation.cagrad">
241-
<span id="cagrad"></span><h1>CAGrad<a class="headerlink" href="#module-torchjd.aggregation.cagrad" title="Link to this heading"></a></h1>
242-
<dl class="py class">
243-
<dt class="sig sig-object py" id="torchjd.aggregation.cagrad.CAGrad">
244-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.cagrad.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/cagrad.py#L10-L44"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.cagrad.CAGrad" title="Link to this definition"></a></dt>
245-
<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
246-
<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
247-
<dl class="field-list simple">
248-
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
249-
<dd class="field-odd"><ul class="simple">
250-
<li><p><strong>c</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – The scale of the radius of the ball constraint.</p></li>
251-
<li><p><strong>norm_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li>
252-
</ul>
253-
</dd>
254-
</dl>
255-
<div class="admonition-example admonition">
256-
<p class="admonition-title">Example</p>
257-
<p>Use CAGrad to aggregate a matrix.</p>
258-
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
259-
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">CAGrad</span>
260-
<span class="gp">&gt;&gt;&gt;</span>
261-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">CAGrad</span><span class="p">(</span><span class="n">c</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
262-
<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
263-
<span class="gp">&gt;&gt;&gt;</span>
264-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
265-
<span class="go">tensor([0.1835, 1.2041, 1.2041])</span>
266-
</pre></div>
267-
</div>
268-
</div>
269-
</dd></dl>
270-
240+
<section id="cagrad">
241+
<h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading"></a></h1>
271242
</section>
272243

273244
</article>

latest/docs/aggregation/config/index.html

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -237,39 +237,8 @@
237237
</label>
238238
</div>
239239
<article role="main">
240-
<section id="module-torchjd.aggregation.config">
241-
<span id="config"></span><h1>ConFIG<a class="headerlink" href="#module-torchjd.aggregation.config" title="Link to this heading"></a></h1>
242-
<dl class="py class">
243-
<dt class="sig sig-object py" id="torchjd.aggregation.config.ConFIG">
244-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.config.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/config.py#L36-L83"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.config.ConFIG" title="Link to this definition"></a></dt>
245-
<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG: Towards
246-
Conflict-free Training of Physics Informed Neural Networks</a>.</p>
247-
<dl class="field-list simple">
248-
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
249-
<dd class="field-odd"><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector used to weight the rows. If not provided, defaults to
250-
equal weights of 1.</p>
251-
</dd>
252-
</dl>
253-
<div class="admonition-example admonition">
254-
<p class="admonition-title">Example</p>
255-
<p>Use ConFIG to aggregate a matrix.</p>
256-
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
257-
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">ConFIG</span>
258-
<span class="gp">&gt;&gt;&gt;</span>
259-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">ConFIG</span><span class="p">()</span>
260-
<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
261-
<span class="gp">&gt;&gt;&gt;</span>
262-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
263-
<span class="go">tensor([0.1588, 2.0706, 2.0706])</span>
264-
</pre></div>
265-
</div>
266-
</div>
267-
<div class="admonition note">
268-
<p class="admonition-title">Note</p>
269-
<p>This implementation was adapted from the <a class="reference external" href="https://github.com/tum-pbs/ConFIG/tree/main/conflictfree">official implementation</a>.</p>
270-
</div>
271-
</dd></dl>
272-
240+
<section id="config">
241+
<h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading"></a></h1>
273242
</section>
274243

275244
</article>

latest/docs/aggregation/dualproj/index.html

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -237,44 +237,8 @@
237237
</label>
238238
</div>
239239
<article role="main">
240-
<section id="module-torchjd.aggregation.dualproj">
241-
<span id="dualproj"></span><h1>DualProj<a class="headerlink" href="#module-torchjd.aggregation.dualproj" title="Link to this heading"></a></h1>
242-
<dl class="py class">
243-
<dt class="sig sig-object py" id="torchjd.aggregation.dualproj.DualProj">
244-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.dualproj.</span></span><span class="sig-name descname"><span class="pre">DualProj</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/dualproj.py#L12-L67"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.dualproj.DualProj" title="Link to this definition"></a></dt>
245-
<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> that averages the rows of the input matrix, and
246-
projects the result onto the dual cone of the rows of the matrix. This corresponds to the
247-
solution to Equation 11 of <a class="reference external" href="https://proceedings.neurips.cc/paper/2017/file/f87522788a2be2d171666752f97ddebb-Paper.pdf">Gradient Episodic Memory for Continual Learning</a>.</p>
248-
<dl class="field-list simple">
249-
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
250-
<dd class="field-odd"><ul class="simple">
251-
<li><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector used to combine the rows. If not provided, defaults to
252-
the simple averaging.</p></li>
253-
<li><p><strong>norm_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li>
254-
<li><p><strong>reg_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to add to the diagonal of the gramian of the matrix. Due to
255-
numerical errors when computing the gramian, it might not exactly be positive definite.
256-
This issue can make the optimization fail. Adding <code class="docutils literal notranslate"><span class="pre">reg_eps</span></code> to the diagonal of the gramian
257-
ensures that it is positive definite.</p></li>
258-
<li><p><strong>solver</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Literal" title="(in Python v3.13)"><code class="xref py py-data docutils literal notranslate"><span class="pre">Literal</span></code></a>[<code class="docutils literal notranslate"><span class="pre">'quadprog'</span></code>]</span>) – The solver used to optimize the underlying optimization problem.</p></li>
259-
</ul>
260-
</dd>
261-
</dl>
262-
<div class="admonition-example admonition">
263-
<p class="admonition-title">Example</p>
264-
<p>Use DualProj to aggregate a matrix.</p>
265-
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
266-
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">DualProj</span>
267-
<span class="gp">&gt;&gt;&gt;</span>
268-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">DualProj</span><span class="p">()</span>
269-
<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
270-
<span class="gp">&gt;&gt;&gt;</span>
271-
<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
272-
<span class="go">tensor([0.5563, 1.1109, 1.1109])</span>
273-
</pre></div>
274-
</div>
275-
</div>
276-
</dd></dl>
277-
240+
<section id="dualproj">
241+
<h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading"></a></h1>
278242
</section>
279243

280244
</article>

0 commit comments

Comments
 (0)