Skip to content

Commit f134557

Browse files
committed
1 parent 5a2f382 commit f134557

14 files changed

Lines changed: 63 additions & 41 deletions

File tree

latest/_sources/installation.md.txt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
Note that `torchjd` requires Python 3.10, 3.11, 3.12, 3.13 or 3.14 and `torch>=2.0`.
99

10-
Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
11-
when installing `torchjd`. To install them, you can use:
12-
```
13-
pip install "torchjd[cagrad]"
14-
```
15-
```
16-
pip install "torchjd[nash_mtl]"
17-
```
10+
Some aggregators have additional dependencies that are not included by default when installing
11+
`torchjd`. The following table lists the optional dependency groups and the aggregators they enable:
12+
13+
Group | Classes | Dependencies | Install command |
14+
|-----|---------|--------------|-----------------|
15+
| `quadprog_projector` | {class}`~torchjd.linalg.QuadprogProjector` (used in {class}`~torchjd.aggregation.UPGrad` and {class}`~torchjd.aggregation.DualProj`) | [numpy](https://github.com/numpy/numpy), [quadprog](https://github.com/quadprog/quadprog), [qpsolvers](https://github.com/qpsolvers/qpsolvers) | `pip install "torchjd[quadprog_projector]"` |
16+
| `cagrad` | {class}`~torchjd.aggregation.CAGrad` | [numpy](https://github.com/numpy/numpy), [cvxpy](https://github.com/cvxpy/cvxpy/) | `pip install "torchjd[cagrad]"` |
17+
| `nash_mtl` | {class}`~torchjd.aggregation.NashMTL` | [numpy](https://github.com/numpy/numpy), [cvxpy](https://github.com/cvxpy/cvxpy/), [ecos](https://github.com/embotech/ecos) | `pip install "torchjd[nash_mtl]"` |
1818

1919
To install `torchjd` with all of its optional dependencies, you can also use:
2020
```

latest/docs/aggregation/cagrad/index.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@
304304
<h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading"></a></h1>
305305
<dl class="py class">
306306
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad">
307-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L137"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
307+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L97-L138"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
308308
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
309309
<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
310310
<dl class="field-list simple">
@@ -323,7 +323,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
323323
</div>
324324
<dl class="py method">
325325
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad.__call__">
326-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition"></a></dt>
326+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition"></a></dt>
327327
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
328328
<dl class="field-list simple">
329329
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -339,10 +339,10 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
339339

340340
<dl class="py class">
341341
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting">
342-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L20-L93"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
342+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L21-L94"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
343343
<dd><dl class="py method">
344344
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting.__call__">
345-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition"></a></dt>
345+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition"></a></dt>
346346
<dd><p>Call self as a function.</p>
347347
<dl class="field-list simple">
348348
<dt class="field-odd">Return type<span class="colon">:</span></dt>

latest/docs/aggregation/config/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ <h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading">¶</
319319
</div>
320320
<dl class="py method">
321321
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG.__call__">
322-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition"></a></dt>
322+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition"></a></dt>
323323
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
324324
<dl class="field-list simple">
325325
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/dualproj/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ <h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">
319319
</dl>
320320
<dl class="py method">
321321
<dt class="sig sig-object py" id="torchjd.aggregation.DualProj.__call__">
322-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition"></a></dt>
322+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition"></a></dt>
323323
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
324324
<dl class="field-list simple">
325325
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -349,7 +349,7 @@ <h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">
349349
</dl>
350350
<dl class="py method">
351351
<dt class="sig sig-object py" id="torchjd.aggregation.DualProjWeighting.__call__">
352-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition"></a></dt>
352+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition"></a></dt>
353353
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
354354
<dl class="field-list simple">
355355
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/graddrop/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ <h1>GradDrop<a class="headerlink" href="#graddrop" title="Link to this heading">
320320
</dl>
321321
<dl class="py method">
322322
<dt class="sig sig-object py" id="torchjd.aggregation.GradDrop.__call__">
323-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition"></a></dt>
323+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition"></a></dt>
324324
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
325325
<dl class="field-list simple">
326326
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/gradvac/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading">¶
340340
</div>
341341
<dl class="py method">
342342
<dt class="sig sig-object py" id="torchjd.aggregation.GradVac.__call__">
343-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVac.__call__" title="Link to this definition"></a></dt>
343+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVac.__call__" title="Link to this definition"></a></dt>
344344
<dd><p>Call self as a function.</p>
345345
<dl class="field-list simple">
346346
<dt class="field-odd">Return type<span class="colon">:</span></dt>
@@ -402,7 +402,7 @@ <h1>GradVac<a class="headerlink" href="#gradvac" title="Link to this heading">¶
402402

403403
<dl class="py method">
404404
<dt class="sig sig-object py" id="torchjd.aggregation.GradVacWeighting.__call__">
405-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVacWeighting.__call__" title="Link to this definition"></a></dt>
405+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.GradVacWeighting.__call__" title="Link to this definition"></a></dt>
406406
<dd><p>Call self as a function.</p>
407407
<dl class="field-list simple">
408408
<dt class="field-odd">Return type<span class="colon">:</span></dt>

0 commit comments

Comments
 (0)