
Commit 9db9f5d

1 parent 4814022 commit 9db9f5d

2 files changed

Lines changed: 16 additions & 11 deletions

File tree

latest/docs/autogram/engine/index.html

Lines changed: 15 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -251,7 +251,7 @@
251251
<h1>Engine<a class="headerlink" href="#engine" title="Link to this heading"></a></h1>
252252
<dl class="py class">
253253
<dt class="sig sig-object py" id="torchjd.autogram.Engine">
254-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L314"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition"></a></dt>
254+
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L323"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition"></a></dt>
255255
<dd><p>Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct
256256
parameters of all provided modules. It is based on Algorithm 3 of <a class="reference external" href="https://arxiv.org/pdf/2406.16232">Jacobian Descent For
257257
Multi-Objective Optimization</a> but goes even further:</p>
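As a reference point for what the engine computes, the Gramian of the Jacobian is G = J Jᵀ, where row i of J is the gradient of the i-th output entry with respect to all parameters. This can be sketched with plain `torch.autograd` (a conceptual illustration only, not the engine's algorithm; the toy model and data are made up):

```python
import torch

# Toy model: one linear layer; a vector output with one entry per batch element.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(3, 4)
losses = model(x).squeeze(-1)  # shape (3,)

# Build the Jacobian of the losses w.r.t. all parameters, row by row.
params = list(model.parameters())
rows = []
for i in range(losses.shape[0]):
    grads = torch.autograd.grad(losses[i], params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
jacobian = torch.stack(rows)     # shape (3, 5): 4 weight entries + 1 bias
gramian = jacobian @ jacobian.T  # shape (3, 3), symmetric positive semi-definite
```

The engine computes this same quantity without ever materializing the full Jacobian, which is the point of the algorithm referenced above.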
@@ -319,24 +319,29 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
319319
</div>
320320
<div class="admonition warning">
321321
<p class="admonition-title">Warning</p>
322-
<p>When providing a non-None <code class="docutils literal notranslate"><span class="pre">batch_dim</span></code>, all provided modules must respect a few
323-
conditions:</p>
322+
<p>When providing a non-None <code class="docutils literal notranslate"><span class="pre">batch_dim</span></code>, all provided modules must respect a few conditions:</p>
324323
<ul class="simple">
325324
<li><p>They should treat the elements of the batch independently. Most common layers respect
326325
this, but for example <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html">BatchNorm</a> does not (it
327326
computes some average and standard deviation over the elements of the batch).</p></li>
328327
<li><p>Their inputs and outputs can be anything, but each input tensor and each output tensor
329-
must be batched on its first dimension. <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a> and <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html">RNNs</a> are thus not
330-
supported yet. This is only an implementation issue, so it should be fixed soon (please
331-
open an issue if you need extra focus on this).</p></li>
328+
must be batched on its first dimension. When available (e.g. in <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a>,
329+
<a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html">MultiheadAttention</a>,
330+
etc.), the <code class="docutils literal notranslate"><span class="pre">batch_first</span></code> parameter has to be set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html">RNNs</a> are thus not supported yet
331+
because their hidden state is batched on dimension 1 even if <code class="docutils literal notranslate"><span class="pre">batch_first</span></code> is <code class="docutils literal notranslate"><span class="pre">True</span></code>.</p></li>
332332
<li><p>They should not perform in-place operations on tensors (for instance you should not use
333333
<code class="docutils literal notranslate"><span class="pre">track_running_stats=True</span></code> in normalization layers).</p></li>
334334
<li><p>They should not have side effects during the forward pass (since their forward pass will
335335
be called twice, the side effects could be different from what’s expected).</p></li>
336336
<li><p>If they have some randomness during the forward pass, they should not have direct
337-
trainable parameters. It is, however, perfectly fine for random modules to have child
338-
modules that have trainable parameters, so if you have a random module with some direct
339-
parameters, a simple fix is to wrap these parameters into a child module.</p></li>
337+
trainable parameters. For this reason,
338+
<a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html">Transformers</a>, which use a
339+
dropout function (rather than a <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html">Dropout</a> layer) in a
340+
module with some trainable parameters, have to be used with
341+
<code class="docutils literal notranslate"><span class="pre">dropout=0.0</span></code>. Note that <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html">Dropout</a> layers are
342+
entirely supported and should be preferred. It is also perfectly fine for random modules
343+
to have child modules that have trainable parameters, so if you have a random module with
344+
some direct parameters, a simple fix is to wrap these parameters into a child module.</p></li>
340345
</ul>
341346
<p>If you’re building your own architecture, respecting those criteria should be quite easy.
342347
However, if you’re using an existing architecture, you may have to modify it to make it
@@ -371,7 +376,7 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
371376
</div>
372377
<dl class="py method">
373378
<dt class="sig sig-object py" id="torchjd.autogram.Engine.compute_gramian">
374-
<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L214-L283"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition"></a></dt>
379+
<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L223-L292"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition"></a></dt>
375380
<dd><p>Computes the Gramian of the Jacobian of <code class="docutils literal notranslate"><span class="pre">output</span></code> with respect to the direct parameters of
376381
all <code class="docutils literal notranslate"><span class="pre">modules</span></code>.</p>
377382
<dl class="field-list simple">
