 <dd><p>Aggregates the Jacobians stored in the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> and accumulates the result
 into their <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields.</p>
 <dl class="field-list simple">
@@ -308,6 +308,11 @@ <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this hea
 the Jacobians, <code class="docutils literal notranslate"><span class="pre">jac_to_grad</span></code> will also return the computed weights.</p></li>
 <li><p><strong>retain_jac</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></a></span>) – Whether to preserve the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the tensors after they have been
 used. Defaults to <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li>
+<li><p><strong>optimize_gramian_computation</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></a></span>) – When the <code class="docutils literal notranslate"><span class="pre">aggregator</span></code> computes weights based on the
+Gramian of the Jacobian, it’s possible to skip the concatenation of the Jacobians and to
+instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This
+saves memory (up to 50% memory saving) but can be slightly slower (up to 15%) on CUDA. We
+advise to try this optimization if memory is an issue for you. Defaults to <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li>
@@ -316,10 +321,21 @@ <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this hea
 </dl>
 <div class="admonition note">
 <p class="admonition-title">Note</p>
-<p>This function starts by “flattening” the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields into matrices (i.e. flattening all
-of their dimensions except the first one), then concatenates those matrices into a combined
-Jacobian matrix. The aggregator is then used on this matrix, which returns a combined
-gradient vector, that is split and reshaped to fit into the <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields of the tensors.</p>
+<p>When <code class="docutils literal notranslate"><span class="pre">optimize_gramian_computation=False</span></code>, this function starts by “flattening” the
+<code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields into matrices (i.e. flattening all of their dimensions except the first
+one), then concatenates those matrices into a combined Jacobian matrix. The <code class="docutils literal notranslate"><span class="pre">aggregator</span></code>
+is then used on this matrix, which returns a combined gradient vector, that is split and
+reshaped to fit into the <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields of the tensors.</p>
+</div>
+<div class="admonition note">
+<p class="admonition-title">Note</p>
+<p>When <code class="docutils literal notranslate"><span class="pre">optimize_gramian_computation=True</span></code>, this function computes and sums the Gramian
+of each individual <code class="docutils literal notranslate"><span class="pre">.jac</span></code> field, iteratively. The inner weighting of the <code class="docutils literal notranslate"><span class="pre">aggregator</span></code> is
+then used to extract some weights from the obtained Gramian, used to compute a linear
+combination of the rows of each <code class="docutils literal notranslate"><span class="pre">.jac</span></code> field, to be stored into the corresponding
+<code class="docutils literal notranslate"><span class="pre">.grad</span></code> field. This is mathematically equivalent to the approach with
+<code class="docutils literal notranslate"><span class="pre">optimize_gramian_computation=False</span></code>, but saves memory by not having to hold the
+concatenated Jacobian matrix in memory at any time.</p>
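The equivalence claimed by the two notes in this diff can be checked on a toy example. The following is a minimal pure-Python sketch under stated assumptions, not the library's implementation: `gramian`, `toy_weighting` and `combine` are hypothetical helpers, the Jacobians are small, already-flattened lists of rows, and the weighting (normalized inverse squared row norms, read from the Gramian's diagonal) is an illustrative choice made only because it depends on the Gramian. A real implementation would presumably operate on tensors.

```python
# Hypothetical helpers for illustration only; none of these names come from the library.

def gramian(jac):
    """Return jac @ jac.T for a matrix given as a list of rows."""
    return [[sum(x * y for x, y in zip(r1, r2)) for r2 in jac] for r1 in jac]


def add(a, b):
    """Element-wise sum of two matrices of the same shape."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def toy_weighting(g):
    """Assumed Gramian-based weighting: normalized inverse squared row norms."""
    inv = [1.0 / g[i][i] for i in range(len(g))]
    total = sum(inv)
    return [v / total for v in inv]


def combine(weights, jac):
    """Linear combination of the rows of jac, i.e. weights @ jac."""
    n_cols = len(jac[0])
    return [sum(w * row[j] for w, row in zip(weights, jac)) for j in range(n_cols)]


# Two already-flattened Jacobians sharing the same first dimension (2 rows).
jac_a = [[1.0, 2.0], [0.0, 1.0]]            # tensor a has 2 elements
jac_b = [[3.0, 0.0, 1.0], [1.0, 1.0, 1.0]]  # tensor b has 3 elements

# Approach with concatenation: build the combined Jacobian, weight it,
# then split the combined gradient vector back per tensor.
combined = [ra + rb for ra, rb in zip(jac_a, jac_b)]
w_concat = toy_weighting(gramian(combined))
grad_concat = combine(w_concat, combined)
grad_a_1, grad_b_1 = grad_concat[:2], grad_concat[2:]

# Approach with summed Gramians: never build the combined Jacobian.
gramian_sum = add(gramian(jac_a), gramian(jac_b))
w_sum = toy_weighting(gramian_sum)
grad_a_2 = combine(w_sum, jac_a)
grad_b_2 = combine(w_sum, jac_b)
```

Both approaches should yield the same per-tensor gradients here, because the Gramian of a column-wise concatenation equals the sum of the Gramians of its blocks. The second approach only ever holds one Jacobian block plus a small square Gramian at a time, which is where the memory saving described in the diff would come from.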