251 | 251 | <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</a></h1> |
252 | 252 | <dl class="py class"> |
253 | 253 | <dt class="sig sig-object py" id="torchjd.autogram.Engine"> |
254 | | -<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L315"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition">¶</a></dt> |
| 254 | +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L314"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition">¶</a></dt> |
255 | 255 | <dd><p>Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct |
256 | 256 | parameters of all provided modules. It is based on Algorithm 3 of <a class="reference external" href="https://arxiv.org/pdf/2406.16232">Jacobian Descent For |
257 | 257 | Multi-Objective Optimization</a> but goes even further:</p> |
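For context on what this engine computes: if `J` is the Jacobian of the output vector with respect to the (flattened) parameters, its Gramian is `G = J @ J.T`, whose entry `(i, j)` is the inner product of the gradients of objectives `i` and `j`. A minimal pure-Python sketch of that definition (toy numbers for illustration; this is not the torchjd API itself):

```python
# Toy illustration of the Gramian of a Jacobian: for m objectives and
# n parameters, J has shape (m, n) and its Gramian G = J @ J^T has
# shape (m, m). Entry G[i][j] is the inner product of the gradient of
# objective i with the gradient of objective j.
def gramian(J):
    m = len(J)
    return [
        [sum(J[i][k] * J[j][k] for k in range(len(J[i]))) for j in range(m)]
        for i in range(m)
    ]

J = [[1.0, 2.0], [3.0, 4.0]]  # Jacobian of 2 objectives w.r.t. 2 parameters
G = gramian(J)                # [[5.0, 11.0], [11.0, 25.0]], symmetric PSD
```

The point of the engine is to obtain `G` without ever materializing the full `J`, which can be huge when the model has many parameters.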
@@ -325,9 +325,8 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</ |
325 | 325 | <li><p>They should treat the elements of the batch independently. Most common layers respect |
326 | 326 | this, but for example <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html">BatchNorm</a> does not (it |
327 | 327 | computes some average and standard deviation over the elements of the batch).</p></li> |
328 | | -<li><p>Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of |
329 | | -tensors, or any nesting of those structures), but each of these tensors must be batched on |
330 | | -its first dimension. <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a> and <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html">RNNs</a> are thus not |
| 328 | +<li><p>Their inputs and outputs can be anything, but each input tensor and each output tensor |
| 329 | +must be batched on its first dimension. <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a> and <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html">RNNs</a> are thus not |
331 | 330 | supported yet. This is only an implementation issue, so it should be fixed soon (please |
332 | 331 | open an issue if you need extra focus on this).</p></li> |
333 | 332 | <li><p>They should not perform in-place operations on tensors (for instance you should not use |
@@ -372,7 +371,7 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</ |
372 | 371 | </div> |
373 | 372 | <dl class="py method"> |
374 | 373 | <dt class="sig sig-object py" id="torchjd.autogram.Engine.compute_gramian"> |
375 | | -<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L215-L284"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition">¶</a></dt> |
| 374 | +<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L214-L283"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition">¶</a></dt> |
376 | 375 | <dd><p>Computes the Gramian of the Jacobian of <code class="docutils literal notranslate"><span class="pre">output</span></code> with respect to the direct parameters of |
377 | 376 | all <code class="docutils literal notranslate"><span class="pre">modules</span></code>.</p> |
378 | 377 | <dl class="field-list simple"> |
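One reason the Gramian alone is useful downstream: it contains every pairwise gradient inner product, so quantities such as the cosine similarity between two objectives' gradients can be read off it directly, with no access to the Jacobian. A small pure-Python sketch of that property (illustrative only, not part of the torchjd API):

```python
import math

# Given a Gramian G = J @ J^T, the cosine similarity between the
# gradients of objectives i and j is G[i][j] / sqrt(G[i][i] * G[j][j]);
# a negative value indicates the two objectives' gradients conflict.
def gradient_cosine(G, i, j):
    return G[i][j] / math.sqrt(G[i][i] * G[j][j])

# Gramian of the gradients [2, 0] and [-1, 0], which point in exactly
# opposite directions:
G = [[4.0, -2.0], [-2.0, 1.0]]
print(gradient_cosine(G, 0, 1))  # -1.0: the two objectives fully conflict
```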