
Commit 97da61e

1 parent 6848743 commit 97da61e

File tree

2 files changed: +7 −8 lines changed


latest/docs/autojac/jac/index.html

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,7 @@
295295
<h1>jac<a class="headerlink" href="#jac" title="Link to this heading"></a></h1>
296296
<dl class="py function">
297297
<dt class="sig sig-object py" id="torchjd.autojac.jac">
298-
<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_outputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac.py#L20-L168"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac" title="Link to this definition"></a></dt>
298+
<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_outputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac.py#L19-L163"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac" title="Link to this definition"></a></dt>
299299
<dd><p>Computes the Jacobians of <code class="docutils literal notranslate"><span class="pre">outputs</span></code> with respect to <code class="docutils literal notranslate"><span class="pre">inputs</span></code>, left-multiplied by
300300
<code class="docutils literal notranslate"><span class="pre">jac_outputs</span></code> (or identity if <code class="docutils literal notranslate"><span class="pre">jac_outputs</span></code> is <code class="docutils literal notranslate"><span class="pre">None</span></code>), and returns the result as a tuple,
301301
with one Jacobian per input tensor. The returned Jacobian with respect to input <code class="docutils literal notranslate"><span class="pre">t</span></code> has shape
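Aside (not part of the commit): the docstring's phrase "left-multiplied by `jac_outputs` (or identity if `jac_outputs` is `None`)" can be sketched in plain Python. The matrices below are assumptions for illustration, taken to match the doc's own example with `param = [1., 2.]`; the helper `matmul` is a hypothetical stand-in for a tensor matrix product, not a torchjd API.

```python
# Hedged sketch: "left-multiplied by jac_outputs" means the returned matrix
# is G @ J, where J is the plain Jacobian and G is jac_outputs; taking G to
# be the identity recovers J itself (the jac_outputs=None case).

def matmul(a, b):
    # Multiply two matrices given as lists of rows.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

J = [[-1.0, 1.0],   # d y1 / d param  (y1 = [-1, 1] @ param)
     [2.0, 4.0]]    # d y2 / d param  (y2 = (param ** 2).sum(), param = [1, 2])

identity = [[1.0, 0.0], [0.0, 1.0]]
assert matmul(identity, J) == J   # jac_outputs=None behaves like identity

G = [[0.5, 0.5]]                  # a custom 1 x 2 jac_outputs
print(matmul(G, J))               # [[0.5, 2.5]]
```

With a non-identity `G`, the two output rows are mixed before being returned, which is what lets a caller backpropagate an arbitrary set of initial Jacobian rows.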
@@ -304,9 +304,8 @@ <h1>jac<a class="headerlink" href="#jac" title="Link to this heading">¶</a></h1
304304
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
305305
<dd class="field-odd"><ul class="simple">
306306
<li><p><strong>outputs</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a></span>) – The tensor or tensors to differentiate. Should be non-empty.</p></li>
307-
<li><p><strong>inputs</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The tensors with respect to which the Jacobian must be computed. These must have
308-
their <code class="docutils literal notranslate"><span class="pre">requires_grad</span></code> flag set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. If not provided, defaults to the leaf tensors
309-
that were used to compute the <code class="docutils literal notranslate"><span class="pre">outputs</span></code> parameter.</p></li>
307+
<li><p><strong>inputs</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a></span>) – The tensor or tensors with respect to which the Jacobian must be computed. These
308+
must have their <code class="docutils literal notranslate"><span class="pre">requires_grad</span></code> flag set to <code class="docutils literal notranslate"><span class="pre">True</span></code>.</p></li>
310309
<li><p><strong>jac_outputs</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The initial Jacobians to backpropagate, analogous to the <code class="docutils literal notranslate"><span class="pre">grad_outputs</span></code>
311310
parameter of <a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad.html#torch.autograd.grad" title="(in PyTorch v2.10)"><code class="xref py py-func docutils literal notranslate"><span class="pre">torch.autograd.grad()</span></code></a>. If provided, it must have the same structure as
312311
<code class="docutils literal notranslate"><span class="pre">outputs</span></code> and each tensor in <code class="docutils literal notranslate"><span class="pre">jac_outputs</span></code> must match the shape of the corresponding
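Aside (not part of the commit): the parameter list says `jac_outputs` "must have the same structure as `outputs`" with each tensor matching its counterpart. A reasonable reading, consistent with the later example where `jac_h` has shape `[2, 2]` for an output `h` of shape `[2]`, is that each entry carries one shared leading dimension on top of the output's shape. The checker below is an assumption sketched from that reading, not the library's actual validation code.

```python
# Hedged sketch (an assumption, not torchjd's validation logic): each
# jac_outputs entry is an "initial Jacobian", so its trailing dimensions
# should match the corresponding output's shape, with one shared leading
# dimension m (the number of Jacobian rows being backpropagated).

def check_jac_outputs(output_shapes, jac_output_shapes):
    leading = {j[0] for j in jac_output_shapes}
    if len(leading) != 1:
        raise ValueError(f"inconsistent leading dims: {sorted(leading)}")
    for out, jout in zip(output_shapes, jac_output_shapes):
        if tuple(jout[1:]) != tuple(out):
            raise ValueError(f"{jout} does not extend output shape {out}")

# Mirrors the doc's example: jac_h of shape [2, 2] for an output h of shape [2].
check_jac_outputs([(2,)], [(2, 2)])   # passes silently
```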
@@ -345,7 +344,7 @@ <h1>jac<a class="headerlink" href="#jac" title="Link to this heading">¶</a></h1
345344
<span class="gp">&gt;&gt;&gt; </span><span class="n">y1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">])</span> <span class="o">@</span> <span class="n">param</span>
346345
<span class="gp">&gt;&gt;&gt; </span><span class="n">y2</span> <span class="o">=</span> <span class="p">(</span><span class="n">param</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
347346
<span class="gp">&gt;&gt;&gt;</span>
348-
<span class="gp">&gt;&gt;&gt; </span><span class="n">jacobians</span> <span class="o">=</span> <span class="n">jac</span><span class="p">([</span><span class="n">y1</span><span class="p">,</span> <span class="n">y2</span><span class="p">],</span> <span class="p">[</span><span class="n">param</span><span class="p">])</span>
347+
<span class="gp">&gt;&gt;&gt; </span><span class="n">jacobians</span> <span class="o">=</span> <span class="n">jac</span><span class="p">([</span><span class="n">y1</span><span class="p">,</span> <span class="n">y2</span><span class="p">],</span> <span class="n">param</span><span class="p">)</span>
349348
<span class="gp">&gt;&gt;&gt;</span>
350349
<span class="gp">&gt;&gt;&gt; </span><span class="n">jacobians</span>
351350
<span class="go">(tensor([[-1., 1.],</span>
@@ -405,14 +404,14 @@ <h1>jac<a class="headerlink" href="#jac" title="Link to this heading">¶</a></h1
405404
<span class="gp">&gt;&gt;&gt; </span><span class="n">jac_h</span> <span class="o">=</span> <span class="n">jac</span><span class="p">([</span><span class="n">y1</span><span class="p">,</span> <span class="n">y2</span><span class="p">],</span> <span class="p">[</span><span class="n">h</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># Shape: [2, 2]</span>
406405
<span class="gp">&gt;&gt;&gt;</span>
407406
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)</span>
408-
<span class="gp">&gt;&gt;&gt; </span><span class="n">jac_x</span> <span class="o">=</span> <span class="n">jac</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">jac_outputs</span><span class="o">=</span><span class="n">jac_h</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
407+
<span class="gp">&gt;&gt;&gt; </span><span class="n">jac_x</span> <span class="o">=</span> <span class="n">jac</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">jac_outputs</span><span class="o">=</span><span class="n">jac_h</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
409408
<span class="gp">&gt;&gt;&gt;</span>
410409
<span class="gp">&gt;&gt;&gt; </span><span class="n">jac_x</span>
411410
<span class="go">tensor([[ 2., 4.],</span>
412411
<span class="go"> [ 2., -4.]])</span>
413412
</pre></div>
414413
</div>
415-
<p>This two-step computation is equivalent to directly computing <code class="docutils literal notranslate"><span class="pre">jac([y1,</span> <span class="pre">y2],</span> <span class="pre">[x])</span></code>.</p>
414+
<p>This two-step computation is equivalent to directly computing <code class="docutils literal notranslate"><span class="pre">jac([y1,</span> <span class="pre">y2],</span> <span class="pre">x)</span></code>.</p>
416415
</div>
417416
<div class="admonition warning">
418417
<p class="admonition-title">Warning</p>
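Aside (not part of the commit): the two-step example above relies on the chain rule `d[y1, y2]/dx = (d[y1, y2]/dh) @ (dh/dx)`, which is what passing `jac_outputs=jac_h` implements. The sketch below checks that composition with plain Python lists; the concrete `jac_h` and `dh_dx` values are illustrative assumptions (chosen to reproduce the `[[2., 4.], [2., -4.]]` result shown in the diff), not values taken from the hidden parts of the example, and `matmul` is a hypothetical helper.

```python
# Hedged sketch of the chain-rule composition behind jac_outputs:
# jac_x = jac_h @ dh_dx, i.e. d[y1, y2]/dx = (d[y1, y2]/dh) @ (dh/dx).

def matmul(a, b):
    # Multiply two matrices given as lists of rows.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

jac_h = [[1.0, 1.0],    # assumed d[y1, y2]/dh, shape [2, 2]
         [1.0, -1.0]]
dh_dx = [[2.0, 0.0],    # assumed dh/dx, shape [2, 2]
         [0.0, 4.0]]

jac_x = matmul(jac_h, dh_dx)
print(jac_x)  # [[2.0, 4.0], [2.0, -4.0]]
```

The single-call form `jac([y1, y2], x)` performs this same multiplication internally, which is why the docs call the two computations equivalent.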
