251 | 251 | <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</a></h1> |
252 | 252 | <dl class="py class"> |
253 | 253 | <dt class="sig sig-object py" id="torchjd.autogram.Engine"> |
254 | | -<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L314"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition">¶</a></dt> |
| 254 | +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L323"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition">¶</a></dt> |
255 | 255 | <dd><p>Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct |
256 | 256 | parameters of all provided modules. It is based on Algorithm 3 of <a class="reference external" href="https://arxiv.org/pdf/2406.16232">Jacobian Descent For |
257 | 257 | Multi-Objective Optimization</a> but goes even further:</p> |
@@ -319,24 +319,29 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</ |
319 | 319 | </div> |
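For intuition about the quantity this engine computes: if J is the Jacobian of a vector of objectives with respect to the parameters, its Gramian is G = J Jᵀ, a square symmetric matrix with one row and column per objective. A minimal plain-Python sketch, independent of the library itself (the Jacobian values below are arbitrary, chosen only for the example):

```python
# Plain-Python sketch of the Gramian of a Jacobian: G[i][j] is the dot
# product of rows i and j of J, i.e. G = J @ J^T. The values of J are
# made up for this illustration.

def gramian(jacobian):
    """Return the Gramian G = J J^T of a list-of-lists Jacobian."""
    return [
        [sum(a * b for a, b in zip(row_i, row_j)) for row_j in jacobian]
        for row_i in jacobian
    ]

# Jacobian of 2 objectives with respect to 3 parameters.
J = [
    [1.0, 0.0, 2.0],
    [0.0, 3.0, 1.0],
]

G = gramian(J)
print(G)  # [[5.0, 2.0], [2.0, 10.0]] -- symmetric, one row/column per objective
```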
320 | 320 | <div class="admonition warning"> |
321 | 321 | <p class="admonition-title">Warning</p> |
322 | | -<p>When providing a non-None <code class="docutils literal notranslate"><span class="pre">batch_dim</span></code>, all provided modules must respect a few |
323 | | -conditions:</p> |
| 322 | +<p>When providing a non-None <code class="docutils literal notranslate"><span class="pre">batch_dim</span></code>, all provided modules must respect a few conditions:</p> |
324 | 323 | <ul class="simple"> |
325 | 324 | <li><p>They should treat the elements of the batch independently. Most common layers respect |
326 | 325 | this, but for example <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html">BatchNorm</a> does not (it |
327 | 326 | computes the mean and standard deviation over the elements of the batch).</p></li>
328 | 327 | <li><p>Their inputs and outputs can be anything, but each input tensor and each output tensor |
329 | | -must be batched on its first dimension. <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a> and <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html">RNNs</a> are thus not |
330 | | -supported yet. This is only an implementation issue, so it should be fixed soon (please |
331 | | -open an issue if you need extra focus on this).</p></li> |
| 328 | +must be batched on its first dimension. When available (e.g. in <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a>, |
| 329 | +<a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html">MultiheadAttention</a>, |
| 330 | +etc.), the <code class="docutils literal notranslate"><span class="pre">batch_first</span></code> parameter has to be set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. This requirement also means that <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html">RNNs</a> are not supported yet,
| 331 | +because their hidden state is batched on dimension 1 even when <code class="docutils literal notranslate"><span class="pre">batch_first</span></code> is <code class="docutils literal notranslate"><span class="pre">True</span></code>.</p></li>
332 | 332 | <li><p>They should not perform in-place operations on tensors (for instance you should not use |
333 | 333 | <code class="docutils literal notranslate"><span class="pre">track_running_stats=True</span></code> in normalization layers).</p></li> |
334 | 334 | <li><p>They should not have side effects during the forward pass (since their forward pass will |
335 | 335 | be called twice, the side effects could be different from what’s expected).</p></li> |
336 | 336 | <li><p>If they have some randomness during the forward pass, they should not have direct |
337 | | -trainable parameters. It is, however, perfectly fine for random modules to have child |
338 | | -modules that have trainable parameters, so if you have a random module with some direct |
339 | | -parameters, a simple fix is to wrap these parameters into a child module.</p></li> |
| 337 | +trainable parameters. For this reason, |
| 338 | +<a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a>, which use a |
| 339 | +dropout function (rather than a <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html">Dropout</a> layer) in a |
| 340 | +module with some trainable parameters, have to be used with
| 341 | +<code class="docutils literal notranslate"><span class="pre">dropout=0.0</span></code>. Note that <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html">Dropout</a> layers are
| 342 | +entirely supported and should be preferred. It is also perfectly fine for random modules |
| 343 | +to have child modules that have trainable parameters, so if you have a random module with |
| 344 | +some direct parameters, a simple fix is to wrap these parameters into a child module.</p></li> |
340 | 345 | </ul> |
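The fix suggested in the last bullet (wrapping the direct parameters of a random module into a child module) can be sketched with plain PyTorch. Both module classes below are hypothetical, written only to illustrate the pattern:

```python
import torch
from torch import nn

class Scale(nn.Module):
    """Trivial child module whose only job is to own the parameter."""
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x

class FixedRandomScale(nn.Module):
    """Random during its forward pass, but with NO direct trainable
    parameters: the weight lives in the Scale child module, which is
    exactly the fix described above."""
    def __init__(self, dim):
        super().__init__()
        self.scale = Scale(dim)  # parameter now lives one level down

    def forward(self, x):
        noise = torch.rand_like(x)  # randomness is fine: no direct params here
        return self.scale(x) + noise
```

A quick check that the pattern holds: `FixedRandomScale` itself has no direct parameters (`parameters(recurse=False)` is empty), while the parameter is still reachable through the child for the optimizer.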
341 | 346 | <p>If you’re building your own architecture, respecting those criteria should be quite easy. |
342 | 347 | However, if you’re using an existing architecture, you may have to modify it to make it |
@@ -371,7 +376,7 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</ |
371 | 376 | </div> |
372 | 377 | <dl class="py method"> |
373 | 378 | <dt class="sig sig-object py" id="torchjd.autogram.Engine.compute_gramian"> |
374 | | -<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L214-L283"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition">¶</a></dt> |
| 379 | +<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L223-L292"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition">¶</a></dt> |
375 | 380 | <dd><p>Computes the Gramian of the Jacobian of <code class="docutils literal notranslate"><span class="pre">output</span></code> with respect to the direct parameters of |
376 | 381 | all <code class="docutils literal notranslate"><span class="pre">modules</span></code>.</p> |
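To make the returned quantity concrete, here is a hedged sketch that builds the same Gramian manually with plain autograd, for a tiny model where the Jacobian can be assembled row by row (all names and values are invented for the example; the engine is described as returning this Gramian, not as using this explicit loop):

```python
import torch

# Manually build G = J @ J.T for a tiny model: parameters w, and one
# objective (loss) per batch element. Values are arbitrary.
w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
losses = (x @ w) ** 2  # one objective per batch element

# One Jacobian row per objective, then the Gramian of the stacked rows.
rows = [torch.autograd.grad(loss, w, retain_graph=True)[0] for loss in losses]
J = torch.stack(rows)  # shape (3, 2): 3 objectives, 2 parameters
G = J @ J.T            # shape (3, 3), symmetric positive semi-definite
```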
377 | 382 | <dl class="field-list simple"> |
|