Skip to content

Commit d2d14d0

Browse files
committed
1 parent 8614939 commit d2d14d0

2 files changed

Lines changed: 19 additions & 3 deletions

File tree

latest/docs/autogram/engine/index.html

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@
251251
<h1>Engine<a class="headerlink" href="#engine" title="Link to this heading"></a></h1>
252252
<dl class="py class">
253253
<dt class="sig sig-object py" id="torchjd.autogram.Engine">
254-
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L301"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition"></a></dt>
254+
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L46-L315"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition"></a></dt>
255255
<dd><p>Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct
256256
parameters of all provided modules. It is based on Algorithm 3 of <a class="reference external" href="https://arxiv.org/pdf/2406.16232">Jacobian Descent For
257257
Multi-Objective Optimization</a> but goes even further:</p>
@@ -347,6 +347,22 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
347347
<p>The alternative is to use <code class="docutils literal notranslate"><span class="pre">batch_dim=None</span></code>, but it’s not recommended since it will
348348
increase memory usage by a lot and thus typically slow down computation.</p>
349349
</div>
350+
<div class="admonition warning">
351+
<p class="admonition-title">Warning</p>
352+
<p>Parent modules should call their child modules directly rather than using their child
353+
modules’ parameters themselves. For instance, the following model is not supported:</p>
354+
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="k">class</span><span class="w"> </span><span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
355+
<span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
356+
<span class="gp">&gt;&gt;&gt; </span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
357+
<span class="gp">&gt;&gt;&gt; </span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># Child module</span>
358+
<span class="gp">&gt;&gt;&gt;</span>
359+
<span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
360+
<span class="gp">&gt;&gt;&gt; </span> <span class="c1"># Incorrect: Use the child module&#39;s parameters directly without calling it.</span>
361+
<span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="nb">input</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">T</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">bias</span>
362+
<span class="gp">&gt;&gt;&gt; </span> <span class="c1"># Correct alternative: return self.linear(input)</span>
363+
</pre></div>
364+
</div>
365+
</div>
350366
<div class="admonition note">
351367
<p class="admonition-title">Note</p>
352368
<p>For maximum efficiency, modules should ideally not contain both direct trainable
@@ -356,7 +372,7 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
356372
</div>
357373
<dl class="py method">
358374
<dt class="sig sig-object py" id="torchjd.autogram.Engine.compute_gramian">
359-
<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L201-L270"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition"></a></dt>
375+
<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L215-L284"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition"></a></dt>
360376
<dd><p>Computes the Gramian of the Jacobian of <code class="docutils literal notranslate"><span class="pre">output</span></code> with respect to the direct parameters of
361377
all <code class="docutils literal notranslate"><span class="pre">modules</span></code>.</p>
362378
<dl class="field-list simple">

0 commit comments

Comments
 (0)