You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L20-L49"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
297
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L21-L50"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
298
298
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
299
299
<aclass="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
300
300
<dlclass="field-list simple">
@@ -316,7 +316,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">ConFIG</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_config.py#L37-L74"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition">¶</a></dt>
297
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">ConFIG</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_config.py#L39-L76"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition">¶</a></dt>
298
298
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> as defined in Equation 2 of <aclass="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
299
299
Towards Conflict-free Training of Physics Informed Neural Networks</a>.</p>
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">DualProj</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">reg_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">solver</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">'quadprog'</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_dualproj.py#L16-L59"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProj" title="Link to this definition">¶</a></dt>
297
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">DualProj</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">reg_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">solver</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">'quadprog'</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_dualproj.py#L15-L58"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.DualProj" title="Link to this definition">¶</a></dt>
298
298
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> that averages the rows of the input
299
299
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
300
300
to the solution to Equation 11 of <aclass="reference external" href="https://proceedings.neurips.cc/paper/2017/file/f87522788a2be2d171666752f97ddebb-Paper.pdf">Gradient Episodic Memory for Continual Learning</a>.</p>
@@ -316,7 +316,7 @@ <h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">Flattening</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">weighting</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_flattening.py#L10-L36"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.Flattening" title="Link to this definition">¶</a></dt>
297
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">Flattening</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">weighting</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_flattening.py#L8-L33"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.Flattening" title="Link to this definition">¶</a></dt>
Gramian into a square matrix, extracting a vector of weights from it using a
300
300
<aclass="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Weighting</span></code></a>, and returning the reshaped tensor of
@@ -305,7 +305,7 @@ <h1>Flattening<a class="headerlink" href="#flattening" title="Link to this headi
<ddclass="field-odd"><p><strong>weighting</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Weighting</span></code></a>[<aclass="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Tensor</span></code></a>]</span>) – The weighting to apply to the Gramian matrix.</p>
308
+
<ddclass="field-odd"><p><strong>weighting</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Weighting</span></code></a></span>) – The weighting to apply to the Gramian matrix.</p>
0 commit comments