You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">AlignedMTL</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L39-L61"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition">¶</a></dt>
297
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">AlignedMTL</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">scale_mode</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">'min'</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L41-L74"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition">¶</a></dt>
298
298
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
299
299
<aclass="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
<ddclass="field-odd"><p><strong>pref_vector</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Tensor</span></code></a> | <aclass="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><codeclass="xref py py-obj docutils literal notranslate"><spanclass="pre">None</span></code></a></span>) – The preference vector to use. If not provided, defaults to
303
-
<spanclass="math notranslate nohighlight">\(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\)</span>.</p>
302
+
<ddclass="field-odd"><ulclass="simple">
303
+
<li><p><strong>pref_vector</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Tensor</span></code></a> | <aclass="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><codeclass="xref py py-obj docutils literal notranslate"><spanclass="pre">None</span></code></a></span>) – The preference vector to use. If not provided, defaults to
304
+
<spanclass="math notranslate nohighlight">\(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\)</span>.</p></li>
305
+
<li><p><strong>scale_mode</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.python.org/3/library/typing.html#typing.Literal" title="(in Python v3.14)"><codeclass="xref py py-data docutils literal notranslate"><spanclass="pre">Literal</span></code></a>[<codeclass="docutils literal notranslate"><spanclass="pre">'min'</span></code>, <codeclass="docutils literal notranslate"><spanclass="pre">'median'</span></code>, <codeclass="docutils literal notranslate"><spanclass="pre">'rmse'</span></code>]</span>) – The scaling mode used to build the balance transformation. <codeclass="docutils literal notranslate"><spanclass="pre">"min"</span></code> uses
306
+
the smallest eigenvalue (default), <codeclass="docutils literal notranslate"><spanclass="pre">"median"</span></code> uses the median eigenvalue, and <codeclass="docutils literal notranslate"><spanclass="pre">"rmse"</span></code>
307
+
uses the mean eigenvalue (as in the original implementation).</p></li>
308
+
</ul>
304
309
</dd>
305
310
</dl>
306
311
<divclass="admonition note">
@@ -311,13 +316,18 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">AlignedMTLWeighting</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L64-L101"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition">¶</a></dt>
319
+
<spanclass="property"><spanclass="k"><spanclass="pre">class</span></span><spanclass="w"></span></span><spanclass="sig-prename descclassname"><spanclass="pre">torchjd.aggregation.</span></span><spanclass="sig-name descname"><spanclass="pre">AlignedMTLWeighting</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">pref_vector</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">scale_mode</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">'min'</span></span></em><spanclass="sig-paren">)</span><aclass="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L77-L135"><spanclass="viewcode-link"><spanclass="pre">[source]</span></span></a><aclass="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition">¶</a></dt>
315
320
<dd><p><aclass="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Weighting</span></code></a> giving the weights of
<ddclass="field-odd"><p><strong>pref_vector</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Tensor</span></code></a> | <aclass="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><codeclass="xref py py-obj docutils literal notranslate"><spanclass="pre">None</span></code></a></span>) – The preference vector to use. If not provided, defaults to
320
-
<spanclass="math notranslate nohighlight">\(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\)</span>.</p>
324
+
<ddclass="field-odd"><ulclass="simple">
325
+
<li><p><strong>pref_vector</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><codeclass="xref py py-class docutils literal notranslate"><spanclass="pre">Tensor</span></code></a> | <aclass="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><codeclass="xref py py-obj docutils literal notranslate"><spanclass="pre">None</span></code></a></span>) – The preference vector to use. If not provided, defaults to
326
+
<spanclass="math notranslate nohighlight">\(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\)</span>.</p></li>
327
+
<li><p><strong>scale_mode</strong> (<spanclass="sphinx_autodoc_typehints-type"><aclass="reference external" href="https://docs.python.org/3/library/typing.html#typing.Literal" title="(in Python v3.14)"><codeclass="xref py py-data docutils literal notranslate"><spanclass="pre">Literal</span></code></a>[<codeclass="docutils literal notranslate"><spanclass="pre">'min'</span></code>, <codeclass="docutils literal notranslate"><spanclass="pre">'median'</span></code>, <codeclass="docutils literal notranslate"><spanclass="pre">'rmse'</span></code>]</span>) – The scaling mode used to build the balance transformation. <codeclass="docutils literal notranslate"><spanclass="pre">"min"</span></code> uses
328
+
the smallest eigenvalue (default), <codeclass="docutils literal notranslate"><spanclass="pre">"median"</span></code> uses the median eigenvalue, and <codeclass="docutils literal notranslate"><spanclass="pre">"rmse"</span></code>
329
+
uses the mean eigenvalue (as in the original implementation).</p></li>
0 commit comments