
Commit 9479440

1 parent bf6444f commit 9479440

File tree

2 files changed: +22 −6 lines changed


latest/docs/autojac/jac_to_grad/index.html

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,7 @@
295 295
<h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this heading"></a></h1>
296 296
<dl class="py function">
297 297
<dt class="sig sig-object py" id="torchjd.autojac.jac_to_grad">
298-
<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac_to_grad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_jac</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac_to_grad.py#L35-L129"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac_to_grad" title="Link to this definition"></a></dt>
298+
<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac_to_grad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_jac</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optimize_gramian_computation</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac_to_grad.py#L47-L146"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac_to_grad" title="Link to this definition"></a></dt>
299 299
<dd><p>Aggregates the Jacobians stored in the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> and accumulates the result
300 300
into their <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields.</p>
301 301
<dl class="field-list simple">
@@ -308,6 +308,11 @@ <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this hea
308 308
the Jacobians, <code class="docutils literal notranslate"><span class="pre">jac_to_grad</span></code> will also return the computed weights.</p></li>
309 309
<li><p><strong>retain_jac</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></a></span>) – Whether to preserve the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the tensors after they have been
310 310
used. Defaults to <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li>
311+
<li><p><strong>optimize_gramian_computation</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></a></span>) – When the <code class="docutils literal notranslate"><span class="pre">aggregator</span></code> computes weights based on the
312+
Gramian of the Jacobian, it is possible to skip concatenating the Jacobians and instead
313+
compute the Gramian as the sum of the Gramians of the individual Jacobians. This can save
314+
up to 50% of memory but can be slightly slower (up to 15%) on CUDA. We advise trying
315+
this optimization if memory is an issue for you. Defaults to <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li>
311 316
</ul>
312 317
</dd>
313 318
<dt class="field-even">Return type<span class="colon">:</span></dt>
@@ -316,10 +321,21 @@ <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this hea
316 321
</dl>
317 322
<div class="admonition note">
318 323
<p class="admonition-title">Note</p>
319-
<p>This function starts by “flattening” the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields into matrices (i.e. flattening all
320-
of their dimensions except the first one), then concatenates those matrices into a combined
321-
Jacobian matrix. The aggregator is then used on this matrix, which returns a combined
322-
gradient vector, that is split and reshaped to fit into the <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields of the tensors.</p>
324+
<p>When <code class="docutils literal notranslate"><span class="pre">optimize_gramian_computation=False</span></code>, this function starts by “flattening” the
325+
<code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields into matrices (i.e. flattening all of their dimensions except the first
326+
one), then concatenates those matrices into a combined Jacobian matrix. The <code class="docutils literal notranslate"><span class="pre">aggregator</span></code>
327+
is then used on this matrix, which returns a combined gradient vector that is then split and
328+
reshaped to fit into the <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields of the tensors.</p>
329+
</div>
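The default pipeline described in the note above (flatten, concatenate, aggregate, split, reshape) can be sketched in a few lines of plain Python. This is an illustrative sketch, not the torchjd implementation: the shapes are made up, and a simple mean over objectives stands in for the `aggregator`.

```python
# Hypothetical setup: two tensors "a" and "b" with m = 2 objectives.
# jac_a has shape (2, 2, 2); jac_b has shape (2, 3). All values are illustrative.
m = 2
jac_a = [[[1.0, 2.0], [3.0, 4.0]],   # objective 0, shaped like tensor a (2, 2)
         [[5.0, 6.0], [7.0, 8.0]]]   # objective 1
jac_b = [[0.1, 0.2, 0.3],
         [0.4, 0.5, 0.6]]

# 1) "Flatten" each .jac: keep the first (objective) dim, flatten the rest.
mat_a = [[x for row in obj for x in row] for obj in jac_a]  # (2, 4)
mat_b = jac_b                                               # (2, 3), already flat

# 2) Concatenate into the combined Jacobian matrix (2, 7).
J = [ra + rb for ra, rb in zip(mat_a, mat_b)]

# 3) Aggregate: a plain mean over the m rows stands in for the aggregator,
#    mapping the (m, n) Jacobian to a single combined gradient of length n.
g = [sum(col) / m for col in zip(*J)]

# 4) Split the combined gradient and reshape back into .grad-shaped pieces.
grad_a_flat, grad_b = g[:4], g[4:]
grad_a = [grad_a_flat[0:2], grad_a_flat[2:4]]  # back to tensor a's (2, 2) shape
```

The point of the sketch is only the data flow; a real aggregator would compute non-trivial weights from the rows of `J`.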
330+
<div class="admonition note">
331+
<p class="admonition-title">Note</p>
332+
<p>When <code class="docutils literal notranslate"><span class="pre">optimize_gramian_computation=True</span></code>, this function iteratively computes and
333+
sums the Gramians of the individual <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields. The inner weighting of the
334+
<code class="docutils literal notranslate"><span class="pre">aggregator</span></code> then extracts weights from the resulting Gramian; these weights are used
335+
to compute a linear combination of the rows of each <code class="docutils literal notranslate"><span class="pre">.jac</span></code> field, which is stored in
336+
the corresponding <code class="docutils literal notranslate"><span class="pre">.grad</span></code> field. This is mathematically equivalent to the approach with
337+
<code class="docutils literal notranslate"><span class="pre">optimize_gramian_computation=False</span></code>, but saves memory by never holding the
338+
concatenated Jacobian matrix in memory.</p>
323 339
</div>
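The mathematical equivalence claimed in the note above follows from the block structure of the Gramian: if the combined Jacobian is the column-wise concatenation J = [J₁ | J₂], then J Jᵀ = J₁ J₁ᵀ + J₂ J₂ᵀ. A minimal pure-Python check, with made-up matrices (not torchjd code):

```python
def matmul_t(a, b):
    # a @ b.T for row-major lists of lists.
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

jac_1 = [[1.0, 2.0], [3.0, 4.0]]   # (m=2 objectives, 2 params)
jac_2 = [[5.0], [6.0]]             # (m=2 objectives, 1 param)

# Gramian of the concatenated Jacobian [jac_1 | jac_2] ...
J = [r1 + r2 for r1, r2 in zip(jac_1, jac_2)]
gramian_concat = matmul_t(J, J)

# ... equals the sum of the per-tensor Gramians, computed without ever
# materializing the concatenated matrix J.
g1 = matmul_t(jac_1, jac_1)
g2 = matmul_t(jac_2, jac_2)
gramian_sum = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(g1, g2)]
```

Since only the small m × m Gramian is needed to extract the weights, the per-tensor Jacobians can be consumed one at a time, which is where the memory saving comes from.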
324 340
<div class="admonition-example admonition">
325 341
<p class="admonition-title">Example</p>

0 commit comments
