You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L95-L140"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
305
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">CAGrad</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">c</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">norm_eps</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0.0001</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L96-L138"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition">¶</a></dt>
306
306
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.GramianWeightedAggregator" title="torchjd.aggregation.GramianWeightedAggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">GramianWeightedAggregator</span></code></a> as defined in Algorithm 1 of
307
307
<aclass="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
308
308
<dlclass="field-list simple">
@@ -322,14 +322,14 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">matrix</span></span></em>, <emclass="sig-param"><spanclass="positional-only-separator o"><abbrtitle="Positional-only parameter separator (PEP 570)"><spanclass="pre">/</span></abbr></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition">¶</a></dt>
325
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGrad.__call__" title="Link to this definition">¶</a></dt>
326
326
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">gramian</span></span></em>, <emclass="sig-param"><spanclass="positional-only-separator o"><abbrtitle="Positional-only parameter separator (PEP 570)"><spanclass="pre">/</span></abbr></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_weighting_bases.py#L72-L78"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition">¶</a></dt>
361
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.CAGradWeighting.__call__" title="Link to this definition">¶</a></dt>
362
362
<dd><p>Computes the vector of weights from the input Gramian and applies all registered hooks.</p>
<ddclass="field-odd"><p><strong>gramian</strong>(<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.11)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Tensor</span></code></a></span>) – The Gramian from which the weights must be extracted.</p>
365
+
<ddclass="field-odd"><p><strong>gramian</strong> – The Gramian from which the weights must be extracted.</p>
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">ConFIG</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_config.py#L16-L61"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition">¶</a></dt>
305
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">ConFIG</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_config.py#L17-L59"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition">¶</a></dt>
306
306
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> as defined in Equation 2 of <aclass="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
307
307
Towards Conflict-free Training of Physics Informed Neural Networks</a>.</p>
308
308
<dlclass="field-list simple">
@@ -317,14 +317,14 @@ <h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading">¶</
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">matrix</span></span></em>, <emclass="sig-param"><spanclass="positional-only-separator o"><abbrtitle="Positional-only parameter separator (PEP 570)"><spanclass="pre">/</span></abbr></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aggregator_bases.py#L32-L39"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition">¶</a></dt>
320
+
<spanclass="sig-name descname"><spanclass="pre">__call__</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="o"><spanclass="pre">*</span></span><spanclass="n"><spanclass="pre">args</span></span></em>, <emclass="sig-param"><spanclass="o"><spanclass="pre">**</span></span><spanclass="n"><spanclass="pre">kwargs</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_mixins.py#L27-L29"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.ConFIG.__call__" title="Link to this definition">¶</a></dt>
321
321
<dd><p>Computes the aggregation from the input matrix and applies all registered hooks.</p>
0 commit comments