You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L138"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
305
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L137"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
306
306
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
307
307
<aclass="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
308
308
<dlclass="field-list simple">
@@ -315,14 +315,13 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
315
315
</dl>
316
316
<divclass="admonition note">
317
317
<pclass="admonition-title">Note</p>
318
-
<p>This aggregator is not installed by default. When not installed, trying to import it should
To install it, use <codeclass="docutils literal notranslate"><spanclass="pre">pip</span><spanclass="pre">install</span><spanclass="pre">"torchjd[cagrad]"</span></code>.</p>
318
+
<p>This aggregator requires optional dependencies. When they are not installed, instantiating
319
+
it raises an <aclass="reference external" href="https://docs.python.org/3/library/exceptions.html#ImportError" title="(in Python v3.14)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">ImportError</span></code></a> with installation instructions.
320
+
To install them, use <codeclass="docutils literal notranslate"><spanclass="pre">pip</span><spanclass="pre">install</span><spanclass="pre">"torchjd[cagrad]"</span></code>.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition">¶</a></dt>
324
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition">¶</a></dt>
326
325
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<li><p><strong>c</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.14)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">float</span></code></a></span>) – The scale of the radius of the ball constraint.</p></li>
348
-
<li><p><strong>norm_eps</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.14)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li>
349
-
</ul>
350
-
</dd>
351
-
</dl>
352
-
<divclass="admonition note">
353
-
<pclass="admonition-title">Note</p>
354
-
<p>This implementation differs from the <aclass="reference external" href="https://github.com/Cranial-XIX/CAGrad/">official implementations</a> in the way the underlying optimization problem is
355
-
solved. This uses the <aclass="reference external" href="https://oxfordcontrol.github.io/ClarabelDocs/stable/">CLARABEL</a>
356
-
solver of <aclass="reference external" href="https://www.cvxpy.org/index.html">cvxpy</a> rather than the <aclass="reference external" href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html">scipy.minimize</a>
357
-
function.</p>
358
-
</div>
359
-
<dlclass="py method">
340
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGradWeighting</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L20-L93"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition">¶</a></dt>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition">¶</a></dt>
362
-
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
343
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition">¶</a></dt>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition">¶</a></dt>
320
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition">¶</a></dt>
321
321
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition">¶</a></dt>
325
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition">¶</a></dt>
326
326
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition">¶</a></dt>
360
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition">¶</a></dt>
361
361
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition">¶</a></dt>
321
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition">¶</a></dt>
322
322
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
0 commit comments