Skip to content

Commit c0057c2

Browse files
committed
1 parent 138ace4 commit c0057c2

11 files changed

Lines changed: 32 additions & 54 deletions

File tree

latest/docs/aggregation/cagrad/index.html

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@
302302
<h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading"></a></h1>
303303
<dl class="py class">
304304
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad">
305-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L138"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
305+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L137"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
306306
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
307307
<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
308308
<dl class="field-list simple">
@@ -315,14 +315,13 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
315315
</dl>
316316
<div class="admonition note">
317317
<p class="admonition-title">Note</p>
318-
<p>This aggregator is not installed by default. When not installed, trying to import it should
319-
result in the following error:
320-
<code class="docutils literal notranslate"><span class="pre">ImportError:</span> <span class="pre">cannot</span> <span class="pre">import</span> <span class="pre">name</span> <span class="pre">'CAGrad'</span> <span class="pre">from</span> <span class="pre">'torchjd.aggregation'</span></code>.
321-
To install it, use <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span> <span class="pre">&quot;torchjd[cagrad]&quot;</span></code>.</p>
318+
<p>This aggregator requires optional dependencies. When they are not installed, instantiating
319+
it raises an <a class="reference external" href="https://docs.python.org/3/library/exceptions.html#ImportError" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ImportError</span></code></a> with installation instructions.
320+
To install them, use <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span> <span class="pre">&quot;torchjd[cagrad]&quot;</span></code>.</p>
322321
</div>
323322
<dl class="py method">
324323
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad.__call__">
325-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition"></a></dt>
324+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition"></a></dt>
326325
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
327326
<dl class="field-list simple">
328327
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -338,34 +337,14 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
338337

339338
<dl class="py class">
340339
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting">
341-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L22-L93"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
342-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> [<a class="reference internal" href="../../linalg/psd_matrix/#torchjd.linalg.PSDMatrix" title="torchjd.linalg.PSDMatrix"><code class="xref py py-class docutils literal notranslate"><span class="pre">PSDMatrix</span></code></a>]
343-
giving the weights of <a class="reference internal" href="#torchjd.aggregation.CAGrad" title="torchjd.aggregation.CAGrad"><code class="xref py py-class docutils literal notranslate"><span class="pre">CAGrad</span></code></a>.</p>
344-
<dl class="field-list simple">
345-
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
346-
<dd class="field-odd"><ul class="simple">
347-
<li><p><strong>c</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – The scale of the radius of the ball constraint.</p></li>
348-
<li><p><strong>norm_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li>
349-
</ul>
350-
</dd>
351-
</dl>
352-
<div class="admonition note">
353-
<p class="admonition-title">Note</p>
354-
<p>This implementation differs from the <a class="reference external" href="https://github.com/Cranial-XIX/CAGrad/">official implementations</a> in the way the underlying optimization problem is
355-
solved. This uses the <a class="reference external" href="https://oxfordcontrol.github.io/ClarabelDocs/stable/">CLARABEL</a>
356-
solver of <a class="reference external" href="https://www.cvxpy.org/index.html">cvxpy</a> rather than the <a class="reference external" href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html">scipy.minimize</a>
357-
function.</p>
358-
</div>
359-
<dl class="py method">
340+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L20-L93"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
341+
<dd><dl class="py method">
360342
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting.__call__">
361-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition"></a></dt>
362-
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
343+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition"></a></dt>
344+
<dd><p>Call self as a function.</p>
363345
<dl class="field-list simple">
364-
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
365-
<dd class="field-odd"><p><strong>gramian</strong> – The Gramian from which the weights must be extracted.</p>
366-
</dd>
367-
<dt class="field-even">Return type<span class="colon">:</span></dt>
368-
<dd class="field-even"><p><span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.14)"><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></a></span></p>
346+
<dt class="field-odd">Return type<span class="colon">:</span></dt>
347+
<dd class="field-odd"><p><span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.14)"><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></a></span></p>
369348
</dd>
370349
</dl>
371350
</dd></dl>

latest/docs/aggregation/config/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ <h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading">¶</
317317
</div>
318318
<dl class="py method">
319319
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG.__call__">
320-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition"></a></dt>
320+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition"></a></dt>
321321
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
322322
<dl class="field-list simple">
323323
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/dualproj/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ <h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">
322322
</dl>
323323
<dl class="py method">
324324
<dt class="sig sig-object py" id="torchjd.aggregation.DualProj.__call__">
325-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition"></a></dt>
325+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition"></a></dt>
326326
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
327327
<dl class="field-list simple">
328328
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -357,7 +357,7 @@ <h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">
357357
</dl>
358358
<dl class="py method">
359359
<dt class="sig sig-object py" id="torchjd.aggregation.DualProjWeighting.__call__">
360-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition"></a></dt>
360+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition"></a></dt>
361361
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
362362
<dl class="field-list simple">
363363
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/graddrop/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ <h1>GradDrop<a class="headerlink" href="#graddrop" title="Link to this heading">
318318
</dl>
319319
<dl class="py method">
320320
<dt class="sig sig-object py" id="torchjd.aggregation.GradDrop.__call__">
321-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition"></a></dt>
321+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition"></a></dt>
322322
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
323323
<dl class="field-list simple">
324324
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

0 commit comments

Comments
 (0)