Skip to content

Commit 2700737

Browse files
committed
1 parent c4ee335 commit 2700737

9 files changed

Lines changed: 16 additions & 16 deletions

File tree

latest/docs/aggregation/cagrad/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@
294294
<h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading"></a></h1>
295295
<dl class="py class">
296296
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad">
297-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L20-L49"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
297+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L21-L50"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
298298
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
299299
<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
300300
<dl class="field-list simple">
@@ -316,7 +316,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
316316

317317
<dl class="py class">
318318
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting">
319-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L52-L105"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
319+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L53-L106"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
320320
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
321321
<a class="reference internal" href="#torchjd.aggregation.CAGrad" title="torchjd.aggregation.CAGrad"><code class="xref py py-class docutils literal notranslate"><span class="pre">CAGrad</span></code></a>.</p>
322322
<dl class="field-list simple">

latest/docs/aggregation/config/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@
294294
<h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading"></a></h1>
295295
<dl class="py class">
296296
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG">
297-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_config.py#L37-L74"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
297+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_config.py#L39-L76"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
298298
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
299299
Towards Conflict-free Training of Physics Informed Neural Networks</a>.</p>
300300
<dl class="field-list simple">

latest/docs/aggregation/dualproj/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@
294294
<h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading"></a></h1>
295295
<dl class="py class">
296296
<dt class="sig sig-object py" id="torchjd.aggregation.DualProj">
297-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">DualProj</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_dualproj.py#L16-L59"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProj" title="Link to this definition"></a></dt>
297+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">DualProj</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_dualproj.py#L15-L58"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProj" title="Link to this definition"></a></dt>
298298
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> that averages the rows of the input
299299
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
300300
to the solution to Equation 11 of <a class="reference external" href="https://proceedings.neurips.cc/paper/2017/file/f87522788a2be2d171666752f97ddebb-Paper.pdf">Gradient Episodic Memory for Continual Learning</a>.</p>
@@ -316,7 +316,7 @@ <h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">
316316

317317
<dl class="py class">
318318
<dt class="sig sig-object py" id="torchjd.aggregation.DualProjWeighting">
319-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">DualProjWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_dualproj.py#L62-L95"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProjWeighting" title="Link to this definition"></a></dt>
319+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">DualProjWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_dualproj.py#L61-L94"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.DualProjWeighting" title="Link to this definition"></a></dt>
320320
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
321321
<a class="reference internal" href="#torchjd.aggregation.DualProj" title="torchjd.aggregation.DualProj"><code class="xref py py-class docutils literal notranslate"><span class="pre">DualProj</span></code></a>.</p>
322322
<dl class="field-list simple">

latest/docs/aggregation/flattening/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@
294294
<h1>Flattening<a class="headerlink" href="#flattening" title="Link to this heading"></a></h1>
295295
<dl class="py class">
296296
<dt class="sig sig-object py" id="torchjd.aggregation.Flattening">
297-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">Flattening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weighting</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_flattening.py#L10-L36"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.Flattening" title="Link to this definition"></a></dt>
297+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">Flattening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weighting</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_flattening.py#L8-L33"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.Flattening" title="Link to this definition"></a></dt>
298298
<dd><p><a class="reference internal" href="../#torchjd.aggregation.GeneralizedWeighting" title="torchjd.aggregation._weighting_bases.GeneralizedWeighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">GeneralizedWeighting</span></code></a> flattening the generalized
299299
Gramian into a square matrix, extracting a vector of weights from it using a
300300
<a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a>, and returning the reshaped tensor of
@@ -305,7 +305,7 @@ <h1>Flattening<a class="headerlink" href="#flattening" title="Link to this headi
305305
<code class="docutils literal notranslate"><span class="pre">[2,</span> <span class="pre">3]</span></code>.</p>
306306
<dl class="field-list simple">
307307
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
308-
<dd class="field-odd"><p><strong>weighting</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>]</span>) – The weighting to apply to the Gramian matrix.</p>
308+
<dd class="field-odd"><p><strong>weighting</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a></span>) – The weighting to apply to the Gramian matrix.</p>
309309
</dd>
310310
</dl>
311311
</dd></dl>

0 commit comments

Comments
 (0)