
Commit bf6444f

1 parent: ca876c6

File tree

2 files changed (+12, -7 lines)


latest/docs/autojac/jac_to_grad/index.html

Lines changed: 11 additions & 6 deletions
@@ -295,21 +295,23 @@
 <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.autojac.jac_to_grad">
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac_to_grad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_jac</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac_to_grad.py#L12-L81"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac_to_grad" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac_to_grad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_jac</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac_to_grad.py#L35-L129"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac_to_grad" title="Link to this definition"></a></dt>
 <dd><p>Aggregates the Jacobians stored in the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> and accumulates the result
 into their <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
 <dd class="field-odd"><ul class="simple">
 <li><p><strong>tensors</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>]</span>) – The tensors whose <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields should be aggregated. All Jacobians must
 have the same first dimension (e.g. number of losses).</p></li>
-<li><p><strong>aggregator</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference internal" href="../../aggregation/#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a></span>) – The aggregator used to reduce the Jacobians into gradients.</p></li>
+<li><p><strong>aggregator</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference internal" href="../../aggregation/#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a></span>) – The aggregator used to reduce the Jacobians into gradients. If it uses a
+<a class="reference internal" href="../../aggregation/#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> to combine the rows of
+the Jacobians, <code class="docutils literal notranslate"><span class="pre">jac_to_grad</span></code> will also return the computed weights.</p></li>
 <li><p><strong>retain_jac</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></a></span>) – Whether to preserve the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the tensors after they have been
 used. Defaults to <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li>
 </ul>
 </dd>
 <dt class="field-even">Return type<span class="colon">:</span></dt>
-<dd class="field-even"><p><span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span></p>
+<dd class="field-even"><p><span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span></p>
 </dd>
 </dl>
 <div class="admonition note">
@@ -333,13 +335,16 @@ <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this hea
 <span class="gp">&gt;&gt;&gt; </span><span class="n">y2</span> <span class="o">=</span> <span class="p">(</span><span class="n">param</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
 <span class="gp">&gt;&gt;&gt;</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">backward</span><span class="p">([</span><span class="n">y1</span><span class="p">,</span> <span class="n">y2</span><span class="p">])</span> <span class="c1"># param now has a .jac field</span>
-<span class="gp">&gt;&gt;&gt; </span><span class="n">jac_to_grad</span><span class="p">([</span><span class="n">param</span><span class="p">],</span> <span class="n">aggregator</span><span class="o">=</span><span class="n">UPGrad</span><span class="p">())</span> <span class="c1"># param now has a .grad field</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">weights</span> <span class="o">=</span> <span class="n">jac_to_grad</span><span class="p">([</span><span class="n">param</span><span class="p">],</span> <span class="n">UPGrad</span><span class="p">())</span> <span class="c1"># param now has a .grad field</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="n">param</span><span class="o">.</span><span class="n">grad</span>
-<span class="go">tensor([-1., 1.])</span>
+<span class="go">tensor([0.5000, 2.5000])</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">weights</span>
+<span class="go">tensor([0.5, 0.5])</span>
 </pre></div>
 </div>
 <p>The <code class="docutils literal notranslate"><span class="pre">.grad</span></code> field of <code class="docutils literal notranslate"><span class="pre">param</span></code> now contains the aggregation (by UPGrad) of the Jacobian of
-<span class="math notranslate nohighlight">\(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\)</span> with respect to <code class="docutils literal notranslate"><span class="pre">param</span></code>.</p>
+<span class="math notranslate nohighlight">\(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\)</span> with respect to <code class="docutils literal notranslate"><span class="pre">param</span></code>. In this case, the
+weights used to combine the Jacobian are equal because there was no conflict.</p>
 </div>
 </dd></dl>
 
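The substantive change in this commit: jac_to_grad now returns the weights computed by the aggregator when it combines the rows of the Jacobians through a Weighting, and its return type widens from None to Tensor | None accordingly. Below is a minimal, hedged sketch of the new usage. Two details are assumptions, since they fall outside the diff hunks: the initial value of param (torch.tensor([0., 2.]) reproduces the doctest outputs shown above) and the import path of backward.

import torch

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad  # backward's import path is assumed

param = torch.tensor([0.0, 2.0], requires_grad=True)  # assumed initial value
y1 = param.sum()         # gradient row w.r.t. param: [1., 1.]
y2 = (param ** 2).sum()  # gradient row w.r.t. param: 2 * param = [0., 4.]

backward([y1, y2])  # fills param.jac with the stacked 2x2 Jacobian of [y1, y2]

# Aggregate param.jac into param.grad; retain_jac=True keeps .jac for inspection.
weights = jac_to_grad([param], UPGrad(), retain_jac=True)

print(param.grad)  # tensor([0.5000, 2.5000]) under the assumed param
print(weights)     # tensor([0.5000, 0.5000])

# Per the parameter description above, the weights combine the rows of the
# Jacobian, so the accumulated gradient should equal weights @ param.jac.
torch.testing.assert_close(param.grad, weights @ param.jac)

With these inputs, the two gradient rows [1., 1.] and [0., 4.] have a non-negative inner product, so there is no conflict and UPGrad's combination reduces to the plain mean. This is why the returned weights are equal, matching the tensor([0.5, 0.5]) shown in the updated doctest.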