|
240 | 240 | <article role="main"> |
241 | 241 | <section id="pytorch-lightning-integration"> |
242 | 242 | <h1>PyTorch Lightning Integration<a class="headerlink" href="#pytorch-lightning-integration" title="Link to this heading">¶</a></h1> |
243 | | -<p>To use Jacobian descent with TorchJD in a <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code>, you need to turn off |
244 | | -automatic optimization by setting <code class="docutils literal notranslate"><span class="pre">automatic_optimization</span></code> to <code class="docutils literal notranslate"><span class="pre">False</span></code> and to customize the |
245 | | -<code class="docutils literal notranslate"><span class="pre">training_step</span></code> method to make it call the appropriate TorchJD method (<a class="reference internal" href="../../docs/autojac/backward/"><span class="doc">backward</span></a> or <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a>).</p> |
| 243 | +<p>To use Jacobian descent with TorchJD in a <a class="reference external" href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule" title="(in PyTorch Lightning v2.5.1.post0)"><code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></a>, you need |
| 244 | +to turn off automatic optimization by setting <code class="docutils literal notranslate"><span class="pre">automatic_optimization</span></code> to <code class="docutils literal notranslate"><span class="pre">False</span></code> and to |
| 245 | +customize the <code class="docutils literal notranslate"><span class="pre">training_step</span></code> method to make it call the appropriate TorchJD method |
| 246 | +(<a class="reference internal" href="../../docs/autojac/backward/"><span class="doc">backward</span></a> or <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a>).</p> |
246 | 247 | <p>The following code example demonstrates a basic multi-task learning setup using a |
247 | | -<code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code> that will call <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a> at each training iteration.</p> |
| 248 | +<a class="reference external" href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule" title="(in PyTorch Lightning v2.5.1.post0)"><code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></a> that will call <a class="reference internal" href="../../docs/autojac/mtl_backward/"><span class="doc">mtl_backward</span></a> at each training iteration.</p> |
248 | 249 | <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> |
249 | 250 | <span class="kn">from</span><span class="w"> </span><span class="nn">lightning</span><span class="w"> </span><span class="kn">import</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span> |
250 | 251 | <span class="kn">from</span><span class="w"> </span><span class="nn">lightning.pytorch.utilities.types</span><span class="w"> </span><span class="kn">import</span> <span class="n">OptimizerLRScheduler</span> |
|
0 commit comments