Commit 6848743

1 parent 3ebf1c3 commit 6848743

File tree: 14 files changed (+14 −14 lines)


latest/_sources/examples/amp.rst.txt (1 addition, 1 deletion)

@@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case.
 loss2 = loss_fn(output2, target2)

 scaled_losses = scaler.scale([loss1, loss2])
-mtl_backward(tensors=scaled_losses, features=features)
+mtl_backward(scaled_losses, features=features)
 jac_to_grad(shared_module.parameters(), aggregator)
 scaler.step(optimizer)
 scaler.update()
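The example updates above all make the same change: `tensors` is now passed positionally instead of as the keyword argument `tensors=`. The following minimal sketch uses a hypothetical stand-in for `torchjd.autojac.mtl_backward` (parameter names taken from the rendered signature in this commit; this is not the real implementation) to show why the old keyword style stops working once `tensors` becomes positional-only.

```python
# Hypothetical stand-in for torchjd.autojac.mtl_backward, shown only to
# illustrate the PEP 570 positional-only marker `/` this commit documents.
def mtl_backward(tensors, /, features, *, grad_tensors=None, retain_graph=False):
    # Sketch only: record what was received.
    return {"n_tensors": len(tensors), "features": features}

losses = ["loss1", "loss2"]

# New call style from the updated examples: `tensors` passed positionally.
result = mtl_backward(losses, features="features")

# The old keyword style from before this commit now raises TypeError,
# because parameters before `/` cannot be passed by keyword.
old_style_rejected = False
try:
    mtl_backward(tensors=losses, features="features")
except TypeError:
    old_style_rejected = True
```

This is why the commit touches every example at once: any call written as `mtl_backward(tensors=...)` would fail against the new signature.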

latest/_sources/examples/lightning_integration.rst.txt (1 addition, 1 deletion)

@@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using
 loss2 = mse_loss(output2, target2)

 opt = self.optimizers()
-mtl_backward(tensors=[loss1, loss2], features=features)
+mtl_backward([loss1, loss2], features=features)
 jac_to_grad(self.feature_extractor.parameters(), UPGrad())
 opt.step()
 opt.zero_grad()

latest/_sources/examples/monitoring.rst.txt (1 addition, 1 deletion)

@@ -63,7 +63,7 @@ they have a negative inner product).
 loss1 = loss_fn(output1, target1)
 loss2 = loss_fn(output2, target2)

-mtl_backward(tensors=[loss1, loss2], features=features)
+mtl_backward([loss1, loss2], features=features)
 jac_to_grad(shared_module.parameters(), aggregator)
 optimizer.step()
 optimizer.zero_grad()

latest/_sources/examples/mtl.rst.txt (1 addition, 1 deletion)

@@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
 loss1 = loss_fn(output1, target1)
 loss2 = loss_fn(output2, target2)

-mtl_backward(tensors=[loss1, loss2], features=features)
+mtl_backward([loss1, loss2], features=features)
 jac_to_grad(shared_module.parameters(), aggregator)
 optimizer.step()
 optimizer.zero_grad()

latest/docs/autogram/engine/index.html (1 addition, 1 deletion)

@@ -421,7 +421,7 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
 </div>
 <dl class="py method">
 <dt class="sig sig-object py" id="torchjd.autogram.Engine.compute_gramian">
-<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autogram/_engine.py#L238-L309"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition"></a></dt>
+<span class="sig-name descname"><span class="pre">compute_gramian</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autogram/_engine.py#L238-L309"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine.compute_gramian" title="Link to this definition"></a></dt>
 <dd><p>Computes the Gramian of the Jacobian of <code class="docutils literal notranslate"><span class="pre">output</span></code> with respect to the direct parameters of
 all <code class="docutils literal notranslate"><span class="pre">modules</span></code>.</p>
 <dl class="field-list simple">

latest/docs/autojac/backward/index.html (1 addition, 1 deletion)

@@ -295,7 +295,7 @@
 <h1>backward<a class="headerlink" href="#backward" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.autojac.backward">
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_backward.py#L16-L121"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.backward" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_backward.py#L16-L123"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.backward" title="Link to this definition"></a></dt>
 <dd><p>Computes the Jacobians of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> with respect to <code class="docutils literal notranslate"><span class="pre">inputs</span></code>, left-multiplied by
 <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> (or identity if <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> is <code class="docutils literal notranslate"><span class="pre">None</span></code>), and accumulates the results in the
 <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the <code class="docutils literal notranslate"><span class="pre">inputs</span></code>.</p>

latest/docs/autojac/jac/index.html (1 addition, 1 deletion)

@@ -295,7 +295,7 @@
 <h1>jac<a class="headerlink" href="#jac" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.autojac.jac">
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_outputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac.py#L20-L167"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_outputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac.py#L20-L168"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac" title="Link to this definition"></a></dt>
 <dd><p>Computes the Jacobians of <code class="docutils literal notranslate"><span class="pre">outputs</span></code> with respect to <code class="docutils literal notranslate"><span class="pre">inputs</span></code>, left-multiplied by
 <code class="docutils literal notranslate"><span class="pre">jac_outputs</span></code> (or identity if <code class="docutils literal notranslate"><span class="pre">jac_outputs</span></code> is <code class="docutils literal notranslate"><span class="pre">None</span></code>), and returns the result as a tuple,
 with one Jacobian per input tensor. The returned Jacobian with respect to input <code class="docutils literal notranslate"><span class="pre">t</span></code> has shape
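Note that the diff for `jac` differs from the others: no `/` is added, only the keyword-only marker `*` after `inputs`. A toy stand-in with the same parameter shape (names from the rendered signature; not the real implementation) shows the practical effect: `inputs` can still be passed either way, but `jac_outputs` and later parameters must now be keywords.

```python
# Toy stand-in mirroring the documented shape of torchjd.autojac.jac
# after this commit; illustrative only, not the real function.
def jac(outputs, inputs=None, *, jac_outputs=None, retain_graph=False,
        parallel_chunk_size=None):
    # Sketch only: echo what was received.
    return (outputs, inputs, jac_outputs)

# `inputs` is still positional-or-keyword, so both call styles work:
assert jac("out", "in") == ("out", "in", None)
assert jac("out", inputs="in") == ("out", "in", None)

# Everything after the `*` is keyword-only now, so a third positional
# argument is rejected:
keyword_only_enforced = False
try:
    jac("out", "in", "jac_out")
except TypeError:
    keyword_only_enforced = True
```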

latest/docs/autojac/jac_to_grad/index.html (1 addition, 1 deletion)

@@ -295,7 +295,7 @@
 <h1>jac_to_grad<a class="headerlink" href="#jac-to-grad" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.autojac.jac_to_grad">
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac_to_grad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_jac</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac_to_grad.py#L12-L79"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac_to_grad" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">jac_to_grad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_jac</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_jac_to_grad.py#L12-L81"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.jac_to_grad" title="Link to this definition"></a></dt>
 <dd><p>Aggregates the Jacobians stored in the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> and accumulates the result
 into their <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields.</p>
 <dl class="field-list simple">

latest/docs/autojac/mtl_backward/index.html (1 addition, 1 deletion)

@@ -295,7 +295,7 @@
 <h1>mtl_backward<a class="headerlink" href="#mtl-backward" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.autojac.mtl_backward">
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">mtl_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shared_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_mtl_backward.py#L25-L126"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.mtl_backward" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">mtl_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="positional-only-separator o"><abbr title="Positional-only parameter separator (PEP 570)"><span class="pre">/</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shared_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_mtl_backward.py#L25-L128"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.mtl_backward" title="Link to this definition"></a></dt>
 <dd><p>In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed
 by several task-specific heads. A loss can then be computed for each task.</p>
 <p>This function computes the gradient of each task-specific tensor with respect to its
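The `/` and `*` separators Sphinx renders in the signature above come directly from the function's parameter kinds. A toy function mirroring the documented shape of `mtl_backward` (parameter names copied from the rendered signature; this is an illustration, not the real function) shows how those kinds are surfaced via `inspect`, which is also how `str(signature)` reproduces the `/` and `*` markers.

```python
import inspect

# Toy function mirroring the documented parameter list of
# torchjd.autojac.mtl_backward after this commit (illustrative only).
def mtl_backward(tensors, /, features, *, grad_tensors=None,
                 tasks_params=None, shared_params=None,
                 retain_graph=False, parallel_chunk_size=None):
    """Sketch: body intentionally empty."""

# Map each parameter name to its kind, as inspect reports it.
kinds = {name: p.kind.name
         for name, p in inspect.signature(mtl_backward).parameters.items()}

# Rendering the signature reproduces the separators shown in the docs:
print(str(inspect.signature(mtl_backward)))
```

`tensors` reports as `POSITIONAL_ONLY` (hence the `/` after it), `features` as `POSITIONAL_OR_KEYWOD`-free `POSITIONAL_OR_KEYWORD`, and every parameter after the `*` as `KEYWORD_ONLY`.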

latest/examples/amp/index.html (1 addition, 1 deletion)

@@ -334,7 +334,7 @@ <h1>Automatic Mixed Precision (AMP)<a class="headerlink" href="#automatic-mixed-
 <span class="n">loss2</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">output2</span><span class="p">,</span> <span class="n">target2</span><span class="p">)</span>

 <span class="hll"> <span class="n">scaled_losses</span> <span class="o">=</span> <span class="n">scaler</span><span class="o">.</span><span class="n">scale</span><span class="p">([</span><span class="n">loss1</span><span class="p">,</span> <span class="n">loss2</span><span class="p">])</span>
-</span><span class="hll"> <span class="n">mtl_backward</span><span class="p">(</span><span class="n">tensors</span><span class="o">=</span><span class="n">scaled_losses</span><span class="p">,</span> <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">)</span>
+</span><span class="hll"> <span class="n">mtl_backward</span><span class="p">(</span><span class="n">scaled_losses</span><span class="p">,</span> <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">)</span>
 </span> <span class="n">jac_to_grad</span><span class="p">(</span><span class="n">shared_module</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">aggregator</span><span class="p">)</span>
 <span class="hll"> <span class="n">scaler</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
 </span><span class="hll"> <span class="n">scaler</span><span class="o">.</span><span class="n">update</span><span class="p">()</span>
