Skip to content

Commit 1404079

Browse files
committed
1 parent cb58a64 commit 1404079

23 files changed

Lines changed: 204 additions & 66 deletions

File tree

latest/_sources/docs/aggregation/index.rst.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,21 @@ Abstract base classes
1010
.. autoclass:: torchjd.aggregation.Aggregator
1111
:members: __call__
1212

13+
.. autoclass:: torchjd.aggregation.WeightedAggregator
14+
:members: __call__
15+
16+
.. autoclass:: torchjd.aggregation.GramianWeightedAggregator
17+
:members: __call__
18+
1319
.. autoclass:: torchjd.aggregation.Weighting
1420
:members: __call__
1521

22+
.. autoclass:: torchjd.aggregation.MatrixWeighting
23+
:members: __call__
24+
25+
.. autoclass:: torchjd.aggregation.GramianWeighting
26+
:members: __call__
27+
1628
.. autoclass:: torchjd.aggregation.GeneralizedWeighting
1729
:members: __call__
1830

latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
298298
<dl class="py class">
299299
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTL">
300300
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L90-L139"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
301-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
301+
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
302302
<a class="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
303303
<dl class="field-list simple">
304304
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -318,7 +318,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
318318
</div>
319319
<dl class="py method">
320320
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTL.__call__">
321-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L31-L38"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL.__call__" title="Link to this definition"></a></dt>
321+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL.__call__" title="Link to this definition"></a></dt>
322322
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
323323
<dl class="field-list simple">
324324
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -335,7 +335,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
335335
<dl class="py class">
336336
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTLWeighting">
337337
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L20-L87"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
338-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
338+
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeighting" title="torchjd.aggregation.GramianWeighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeighting</span></code></a> giving the weights of
339339
<a class="reference internal" href="#torchjd.aggregation.AlignedMTL" title="torchjd.aggregation.AlignedMTL"><code class="xref py py-class docutils literal notranslate"><span class="pre">AlignedMTL</span></code></a>.</p>
340340
<dl class="field-list simple">
341341
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/cagrad/index.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
298298
<dl class="py class">
299299
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad">
300300
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L95-L140"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
301-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
301+
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
302302
<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
303303
<dl class="field-list simple">
304304
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -317,7 +317,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
317317
</div>
318318
<dl class="py method">
319319
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad.__call__">
320-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L31-L38"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition"></a></dt>
320+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition"></a></dt>
321321
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
322322
<dl class="field-list simple">
323323
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -334,7 +334,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
334334
<dl class="py class">
335335
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting">
336336
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L21-L92"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
337-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
337+
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GramianWeighting" title="torchjd.aggregation.GramianWeighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">GramianWeighting</span></code></a> giving the weights of
338338
<a class="reference internal" href="#torchjd.aggregation.CAGrad" title="torchjd.aggregation.CAGrad"><code class="xref py py-class docutils literal notranslate"><span class="pre">CAGrad</span></code></a>.</p>
339339
<dl class="field-list simple">
340340
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

latest/docs/aggregation/config/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ <h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading">¶</
298298
<dl class="py class">
299299
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG">
300300
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_config.py#L16-L61"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
301-
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
301+
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
302302
Towards Conflict-free Training of Physics Informed Neural Networks</a>.</p>
303303
<dl class="field-list simple">
304304
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
@@ -312,7 +312,7 @@ <h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading">¶</
312312
</div>
313313
<dl class="py method">
314314
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG.__call__">
315-
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L31-L38"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition"></a></dt>
315+
<span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">matrix</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition"></a></dt>
316316
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
317317
<dl class="field-list simple">
318318
<dt class="field-odd">Parameters<span class="colon">:</span></dt>

0 commit comments

Comments
 (0)