You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L137"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
307
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L97-L138"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
308
308
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
309
309
<aclass="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
310
310
<dlclass="field-list simple">
@@ -323,7 +323,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition">¶</a></dt>
326
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition">¶</a></dt>
327
327
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition">¶</a></dt>
345
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition">¶</a></dt>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition">¶</a></dt>
322
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition">¶</a></dt>
323
323
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition">¶</a></dt>
322
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProj.__call__" title="Link to this definition">¶</a></dt>
323
323
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition">¶</a></dt>
352
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProjWeighting.__call__" title="Link to this definition">¶</a></dt>
353
353
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition">¶</a></dt>
323
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradDrop.__call__" title="Link to this definition">¶</a></dt>
324
324
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVac.__call__" title="Link to this definition">¶</a></dt>
343
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVac.__call__" title="Link to this definition">¶</a></dt>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L54-L56"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVacWeighting.__call__" title="Link to this definition">¶</a></dt>
405
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.GradVacWeighting.__call__" title="Link to this definition">¶</a></dt>
0 commit comments