|
241 | 241 | <span id="nash-mtl"></span><h1>Nash-MTL<a class="headerlink" href="#module-torchjd.aggregation.nash_mtl" title="Link to this heading">¶</a></h1> |
242 | 242 | <dl class="py class"> |
243 | 243 | <dt class="sig sig-object py" id="torchjd.aggregation.nash_mtl.NashMTL"> |
244 | | -<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.nash_mtl.</span></span><span class="sig-name descname"><span class="pre">NashMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">n_tasks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_norm</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_weights_every</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optim_niter</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">20</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/nash_mtl.py#L36-L98"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.nash_mtl.NashMTL" title="Link to this definition">¶</a></dt> |
| 244 | +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.nash_mtl.</span></span><span class="sig-name descname"><span class="pre">NashMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">n_tasks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_norm</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_weights_every</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optim_niter</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">20</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/nash_mtl.py#L39-L103"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.nash_mtl.NashMTL" title="Link to this definition">¶</a></dt> |
245 | 245 | <dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as proposed in Algorithm 1 of |
246 | 246 | <a class="reference external" href="https://arxiv.org/pdf/2202.01017.pdf">Multi-Task Learning as a Bargaining Game</a>.</p> |
247 | 247 | <dl class="field-list simple"> |
|
273 | 273 | </div> |
274 | 274 | <div class="admonition note"> |
275 | 275 | <p class="admonition-title">Note</p> |
276 | | -<p>This aggregator has dependencies that are not included by default when installing |
277 | | -<code class="docutils literal notranslate"><span class="pre">torchjd</span></code>. To install them, use <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span> <span class="pre">torchjd[nash_mtl]</span></code>.</p> |
| 276 | +<p>This aggregator is not installed by default. When not installed, trying to import it should |
| 277 | +result in the following error: |
| 278 | +<code class="docutils literal notranslate"><span class="pre">ImportError:</span> <span class="pre">cannot</span> <span class="pre">import</span> <span class="pre">name</span> <span class="pre">'NashMTL'</span> <span class="pre">from</span> <span class="pre">'torchjd.aggregation'</span></code>. |
| 279 | +To install it, use <code class="docutils literal notranslate"><span class="pre">pip</span> <span class="pre">install</span> <span class="pre">torchjd[nash_mtl]</span></code>.</p> |
278 | 280 | </div> |
279 | 281 | <div class="admonition warning"> |
280 | 282 | <p class="admonition-title">Warning</p> |
|
288 | 290 | </div> |
289 | 291 | <dl class="py method"> |
290 | 292 | <dt class="sig sig-object py" id="torchjd.aggregation.nash_mtl.NashMTL.reset"> |
291 | | -<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/nash_mtl.py#L93-L95"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.nash_mtl.NashMTL.reset" title="Link to this definition">¶</a></dt> |
| 293 | +<span class="sig-name descname"><span class="pre">reset</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/nash_mtl.py#L98-L100"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.nash_mtl.NashMTL.reset" title="Link to this definition">¶</a></dt> |
292 | 294 | <dd><p>Resets the internal state of the algorithm.</p> |
293 | 295 | <dl class="field-list simple"> |
294 | 296 | <dt class="field-odd">Return type<span class="colon">:</span></dt> |
|
0 commit comments