
Commit 6d7de2c

1 parent 769ba12 commit 6d7de2c

20 files changed

Lines changed: 966 additions & 36 deletions

latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 32 additions & 2 deletions
```diff
@@ -237,8 +237,38 @@
 </label>
 </div>
 <article role="main">
-<section id="aligned-mtl">
-<h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this heading"></a></h1>
+<section id="module-torchjd.aggregation.aligned_mtl">
+<span id="aligned-mtl"></span><h1>Aligned-MTL<a class="headerlink" href="#module-torchjd.aggregation.aligned_mtl" title="Link to this heading"></a></h1>
+<dl class="py class">
+<dt class="sig sig-object py" id="torchjd.aggregation.aligned_mtl.AlignedMTL">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.aligned_mtl.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/aligned_mtl.py#L37-L74"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.aligned_mtl.AlignedMTL" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
+<a class="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector to use.</p>
+</dd>
+</dl>
+<div class="admonition-example admonition">
+<p class="admonition-title">Example</p>
+<p>Use AlignedMTL to aggregate a matrix.</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">AlignedMTL</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">AlignedMTL</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
+<span class="go">tensor([0.2133, 0.9673, 0.9673])</span>
+</pre></div>
+</div>
+</div>
+<div class="admonition note">
+<p class="admonition-title">Note</p>
+<p>This implementation was adapted from the <a class="reference external" href="https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned">official implementation</a>.</p>
+</div>
+</dd></dl>
+
 </section>
 
 </article>
```
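The added doctest shows AlignedMTL mapping the 2×3 Jacobian `[[-4., 1., 1.], [6., 1., 1.]]` to `tensor([0.2133, 0.9673, 0.9673])`. For readers without torchjd installed, the balancing step of Algorithm 1 can be sketched for the two-task case in dependency-free Python. The function name and the closed-form 2×2 eigendecomposition are ours, not torchjd's; this is a sketch, not the library implementation, and it assumes the gramian is nonsingular.

```python
import math

def aligned_mtl_2task(g1, g2, pref=(0.5, 0.5)):
    """Sketch of Aligned-MTL (Algorithm 1) for exactly two task gradients.

    The gramian G = J J^T of the Jacobian J = [g1; g2] is eigendecomposed,
    rescaled by sigma_min * Lambda^{-1/2} (so all principal components get
    the smallest singular value), and the resulting weights combine the rows.
    Hypothetical helper, not the torchjd implementation.
    """
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    # Entries of the symmetric 2x2 gramian [[a, b], [b, c]].
    a, b, c = dot(g1, g1), dot(g1, g2), dot(g2, g2)
    # Closed-form eigendecomposition of a symmetric 2x2 matrix.
    mean, diff = (a + c) / 2, (a - c) / 2
    r = math.hypot(diff, b)
    lam1, lam2 = mean + r, mean - r          # lam1 >= lam2, assumed > 0
    theta = math.atan2(b, diff) / 2          # rotation angle of eigenbasis
    v1 = (math.cos(theta), math.sin(theta))  # eigenvector for lam1
    v2 = (-v1[1], v1[0])                     # eigenvector for lam2
    sigma_min = math.sqrt(lam2)
    # Apply B = sigma_min * V diag(lambda^{-1/2}) V^T to the preference vector.
    c1 = dot(v1, pref) * sigma_min / math.sqrt(lam1)
    c2 = dot(v2, pref) * sigma_min / math.sqrt(lam2)
    w = (c1 * v1[0] + c2 * v2[0], c1 * v1[1] + c2 * v2[1])
    # Weighted combination of the Jacobian rows.
    return [w[0] * x + w[1] * y for x, y in zip(g1, g2)]

print(aligned_mtl_2task((-4., 1., 1.), (6., 1., 1.)))  # ≈ [0.2133, 0.9673, 0.9673]
```

On the documented example this reproduces the doctest output up to the small regularization torchjd applies internally.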

latest/docs/aggregation/cagrad/index.html

Lines changed: 36 additions & 2 deletions
```diff
@@ -237,8 +237,42 @@
 </label>
 </div>
 <article role="main">
-<section id="cagrad">
-<h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading"></a></h1>
+<section id="module-torchjd.aggregation.cagrad">
+<span id="cagrad"></span><h1>CAGrad<a class="headerlink" href="#module-torchjd.aggregation.cagrad" title="Link to this heading"></a></h1>
+<dl class="py class">
+<dt class="sig sig-object py" id="torchjd.aggregation.cagrad.CAGrad">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.cagrad.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/cagrad.py#L10-L48"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.cagrad.CAGrad" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
+<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>c</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – The scale of the radius of the ball constraint.</p></li>
+<li><p><strong>norm_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li>
+</ul>
+</dd>
+</dl>
+<div class="admonition-example admonition">
+<p class="admonition-title">Example</p>
+<p>Use CAGrad to aggregate a matrix.</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">CAGrad</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">CAGrad</span><span class="p">(</span><span class="n">c</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
+<span class="go">tensor([0.1835, 1.2041, 1.2041])</span>
+</pre></div>
+</div>
+</div>
+<div class="admonition note">
+<p class="admonition-title">Note</p>
+<p>This aggregator has dependencies that are not included by default when installing
+<code class="docutils literal notranslate"><span class="pre">torchjd</span></code>. To install them, use <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span> <span class="pre">torchjd[cagrad]</span></code>.</p>
+</div>
+</dd></dl>
+
 </section>
 
 </article>
```
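The CAGrad doctest above reports `tensor([0.1835, 1.2041, 1.2041])` for `c=0.5` on the same Jacobian. CAGrad's dual problem (minimize g_w·g_0 + c‖g_0‖·‖g_w‖ over simplex weights w, then move from the average g_0 toward the minimizer) can be sketched without the optional `quadprog` dependency by scanning the one-dimensional simplex of the two-task case. The function name and the grid-scan solver are ours; this is a numerical sketch, not the torchjd implementation.

```python
import math

def cagrad_2task(g1, g2, c=0.5, steps=10_000):
    """Sketch of CAGrad (Algorithm 1) for exactly two task gradients.

    With two tasks the simplex weight is a single scalar w in [0, 1], so a
    fine grid scan replaces the QP solver. Not the torchjd implementation.
    """
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    norm = lambda a: math.sqrt(dot(a, a))
    g0 = [(x + y) / 2 for x, y in zip(g1, g2)]   # average gradient
    radius = c * norm(g0)                        # radius of the ball constraint
    best_w, best_val = 0.0, float("inf")
    for i in range(steps + 1):
        w = i / steps
        gw = [w * x + (1 - w) * y for x, y in zip(g1, g2)]
        val = dot(gw, g0) + radius * norm(gw)    # dual objective
        if val < best_val:
            best_val, best_w = val, w
    gw = [best_w * x + (1 - best_w) * y for x, y in zip(g1, g2)]
    scale = radius / norm(gw)
    return [a + scale * b for a, b in zip(g0, gw)]

print(cagrad_2task((-4., 1., 1.), (6., 1., 1.), c=0.5))  # ≈ [0.1835, 1.2041, 1.2041]
```

On this example the dual minimum lies at the simplex boundary w = 1, and the sketch reproduces the doctest output.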

latest/docs/aggregation/config/index.html

Lines changed: 33 additions & 2 deletions
```diff
@@ -237,8 +237,39 @@
 </label>
 </div>
 <article role="main">
-<section id="config">
-<h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading"></a></h1>
+<section id="module-torchjd.aggregation.config">
+<span id="config"></span><h1>ConFIG<a class="headerlink" href="#module-torchjd.aggregation.config" title="Link to this heading"></a></h1>
+<dl class="py class">
+<dt class="sig sig-object py" id="torchjd.aggregation.config.ConFIG">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.config.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/config.py#L36-L83"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.config.ConFIG" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG: Towards
+Conflict-free Training of Physics Informed Neural Networks</a>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector used to weight the rows. If not provided, defaults to
+equal weights of 1.</p>
+</dd>
+</dl>
+<div class="admonition-example admonition">
+<p class="admonition-title">Example</p>
+<p>Use ConFIG to aggregate a matrix.</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">ConFIG</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">ConFIG</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
+<span class="go">tensor([0.1588, 2.0706, 2.0706])</span>
+</pre></div>
+</div>
+</div>
+<div class="admonition note">
+<p class="admonition-title">Note</p>
+<p>This implementation was adapted from the <a class="reference external" href="https://github.com/tum-pbs/ConFIG/tree/main/conflictfree">official implementation</a>.</p>
+</div>
+</dd></dl>
+
 </section>
 
 </article>
```
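The ConFIG doctest yields `tensor([0.1588, 2.0706, 2.0706])` on the same Jacobian. For two gradients, the conflict-free direction of Equation 2 reduces to a simple construction: take each gradient's component orthogonal to the other, sum the unit versions into a unit direction, and rescale by the total projected magnitude. The sketch below is ours (hypothetical function name, two-gradient special case only; the general form uses a pseudoinverse), not the torchjd implementation.

```python
import math

def config_2grad(g1, g2):
    """Sketch of the ConFIG direction for exactly two gradients.

    Each gradient is replaced by its component orthogonal to the other, the
    unit versions are summed into a conflict-free unit direction g_v, and
    the result is rescaled so both tasks keep an equal positive projection.
    Dependency-free sketch, not the torchjd implementation.
    """
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    norm = lambda a: math.sqrt(dot(a, a))

    def orth(a, b):
        # Component of a orthogonal to b.
        s = dot(a, b) / dot(b, b)
        return [x - s * y for x, y in zip(a, b)]

    def unit(a):
        n = norm(a)
        return [x / n for x in a]

    u1, u2 = unit(orth(g1, g2)), unit(orth(g2, g1))
    gv = unit([x + y for x, y in zip(u1, u2)])      # conflict-free direction
    scale = dot(g1, gv) + dot(g2, gv)               # sum of projected magnitudes
    return [scale * x for x in gv]

print(config_2grad((-4., 1., 1.), (6., 1., 1.)))  # ≈ [0.1588, 2.0706, 2.0706]
```

Both gradients have the same positive cosine similarity with the returned direction, which is the conflict-free property the paper targets.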

latest/docs/aggregation/dualproj/index.html

Lines changed: 38 additions & 2 deletions
```diff
@@ -237,8 +237,44 @@
 </label>
 </div>
 <article role="main">
-<section id="dualproj">
-<h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading"></a></h1>
+<section id="module-torchjd.aggregation.dualproj">
+<span id="dualproj"></span><h1>DualProj<a class="headerlink" href="#module-torchjd.aggregation.dualproj" title="Link to this heading"></a></h1>
+<dl class="py class">
+<dt class="sig sig-object py" id="torchjd.aggregation.dualproj.DualProj">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.dualproj.</span></span><span class="sig-name descname"><span class="pre">DualProj</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/dualproj.py#L12-L67"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.dualproj.DualProj" title="Link to this definition"></a></dt>
+<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> that averages the rows of the input matrix, and
+projects the result onto the dual cone of the rows of the matrix. This corresponds to the
+solution to Equation 11 of <a class="reference external" href="https://proceedings.neurips.cc/paper/2017/file/f87522788a2be2d171666752f97ddebb-Paper.pdf">Gradient Episodic Memory for Continual Learning</a>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector used to combine the rows. If not provided, defaults to
+the simple averaging.</p></li>
+<li><p><strong>norm_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li>
+<li><p><strong>reg_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to add to the diagonal of the gramian of the matrix. Due to
+numerical errors when computing the gramian, it might not exactly be positive definite.
+This issue can make the optimization fail. Adding <code class="docutils literal notranslate"><span class="pre">reg_eps</span></code> to the diagonal of the gramian
+ensures that it is positive definite.</p></li>
+<li><p><strong>solver</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Literal" title="(in Python v3.13)"><code class="xref py py-data docutils literal notranslate"><span class="pre">Literal</span></code></a>[<code class="docutils literal notranslate"><span class="pre">'quadprog'</span></code>]</span>) – The solver used to optimize the underlying optimization problem.</p></li>
+</ul>
+</dd>
+</dl>
+<div class="admonition-example admonition">
+<p class="admonition-title">Example</p>
+<p>Use DualProj to aggregate a matrix.</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">DualProj</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span> <span class="o">=</span> <span class="n">DualProj</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span>
+<span class="go">tensor([0.5563, 1.1109, 1.1109])</span>
+</pre></div>
+</div>
+</div>
+</dd></dl>
+
 </section>
 
 </article>
```
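The DualProj doctest produces `tensor([0.5563, 1.1109, 1.1109])`. The geometry is easy to verify by hand for this Jacobian: the mean of the rows is `[1, 1, 1]`, which violates the first row's constraint `g1 · x >= 0`, so it gets projected onto that halfspace. The sketch below handles exactly this single-active-constraint situation with a closed-form halfspace projection; the torchjd version solves the general quadratic program with `quadprog` (and its `reg_eps` regularization shifts the result by a fraction of a percent). The function name is ours.

```python
def dualproj_2task(g1, g2):
    """Sketch of DualProj for two task gradients.

    Averages the rows, then projects the average onto the dual cone
    {x : g_i . x >= 0 for all i}. Here at most one constraint is active,
    so a single halfspace projection per violated row suffices; the
    general case needs a QP solver. Not the torchjd implementation.
    """
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    x = [(a + b) / 2 for a, b in zip(g1, g2)]          # mean of the rows
    for g in (g1, g2):
        v = dot(g, x)
        if v < 0:                                      # constraint violated
            s = v / dot(g, g)
            x = [xi - s * gi for xi, gi in zip(x, g)]  # project onto halfspace
    return x

print(dualproj_2task((-4., 1., 1.), (6., 1., 1.)))  # [5/9, 10/9, 10/9] ≈ [0.5556, 1.1111, 1.1111]
```

The exact projection is `[5/9, 10/9, 10/9]`, matching the doctest output up to the `reg_eps` regularization.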
