
Commit 2c4a0c6 (parent: 4cc525a)

File tree: 10 files changed (+26, -16 lines)

latest/_sources/examples/amp.rst.txt

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case.
     loss2 = loss_fn(output2, target2)

     scaled_losses = scaler.scale([loss1, loss2])
-    mtl_backward(losses=scaled_losses, features=features)
+    mtl_backward(tensors=scaled_losses, features=features)
     jac_to_grad(shared_module.parameters(), aggregator)
     scaler.step(optimizer)
     scaler.update()
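
To see the renamed keyword in context, here is a minimal, self-contained sketch of an AMP-style training step built around the lines above. The models, data, and the choice of UPGrad as aggregator are placeholders, the import paths for mtl_backward and jac_to_grad are assumed from the torchjd.autojac prefix used elsewhere in this commit, and the autocast-wrapped forward pass of the real example is elided.

import torch
from torch.amp import GradScaler

# Import paths assumed from the ``torchjd.autojac`` prefix in this commit.
from torchjd.autojac import jac_to_grad, mtl_backward
from torchjd.aggregation import UPGrad  # hypothetical choice of aggregator

shared_module = torch.nn.Linear(10, 5)  # shared feature extractor
head1, head2 = torch.nn.Linear(5, 1), torch.nn.Linear(5, 1)  # task heads
params = [*shared_module.parameters(), *head1.parameters(), *head2.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = GradScaler()  # disables itself (with a warning) if CUDA is absent
aggregator = UPGrad()

inputs = torch.randn(16, 10)
target1, target2 = torch.randn(16, 1), torch.randn(16, 1)

features = shared_module(inputs)
loss1 = loss_fn(head1(features), target1)
loss2 = loss_fn(head2(features), target2)

# Scale both losses at once, then backpropagate with the renamed keyword.
scaled_losses = scaler.scale([loss1, loss2])
mtl_backward(tensors=scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)  # .jac -> aggregated .grad
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()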

latest/_sources/examples/lightning_integration.rst.txt

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using
         loss2 = mse_loss(output2, target2)

         opt = self.optimizers()
-        mtl_backward(losses=[loss1, loss2], features=features)
+        mtl_backward(tensors=[loss1, loss2], features=features)
         jac_to_grad(self.feature_extractor.parameters(), UPGrad())
         opt.step()
         opt.zero_grad()
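
Condensed into a runnable shape, the training step this hunk edits might look as follows. This is a sketch, not the example file itself: the module layout and data are placeholders, manual optimization is assumed (implied by calling self.optimizers() and opt.step() by hand), and the TorchJD import paths are taken from the torchjd.autojac prefix used in this commit.

import torch
from lightning import LightningModule  # in older setups: pytorch_lightning
from torch.nn.functional import mse_loss

from torchjd.autojac import jac_to_grad, mtl_backward  # paths assumed
from torchjd.aggregation import UPGrad

class MultiTaskModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # we call opt.step() ourselves
        self.feature_extractor = torch.nn.Linear(10, 5)
        self.head1 = torch.nn.Linear(5, 1)
        self.head2 = torch.nn.Linear(5, 1)

    def training_step(self, batch, batch_idx):
        inputs, target1, target2 = batch
        features = self.feature_extractor(inputs)
        loss1 = mse_loss(self.head1(features), target1)
        loss2 = mse_loss(self.head2(features), target2)

        opt = self.optimizers()
        mtl_backward(tensors=[loss1, loss2], features=features)  # renamed keyword
        jac_to_grad(self.feature_extractor.parameters(), UPGrad())
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)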

latest/_sources/examples/monitoring.rst.txt

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ they have a negative inner product).
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-    mtl_backward(losses=[loss1, loss2], features=features)
+    mtl_backward(tensors=[loss1, loss2], features=features)
     jac_to_grad(shared_module.parameters(), aggregator)
     optimizer.step()
     optimizer.zero_grad()
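
The hunk header above mentions detecting tasks whose gradients conflict (negative inner product). Per the docstring change in this commit, after mtl_backward each shared parameter carries a .jac field with one row per task, so a conflict check can be written against those rows. A sketch under that assumption, for two tasks; it is an illustration of the idea, not the monitoring code from the example file.

import torch

def task_conflict(shared_params) -> torch.Tensor:
    """Inner product of the two tasks' gradients w.r.t. the shared parameters.

    Assumes mtl_backward has already run (and jac_to_grad has not yet consumed
    the Jacobians), so each parameter has a .jac of shape (num_tasks, *p.shape).
    """
    g1, g2 = [], []
    for p in shared_params:
        jac = p.jac  # one row per task, per this commit's docstring
        g1.append(jac[0].flatten())
        g2.append(jac[1].flatten())
    return torch.cat(g1) @ torch.cat(g2)  # negative => conflicting tasks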

latest/_sources/examples/mtl.rst.txt

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-    mtl_backward(losses=[loss1, loss2], features=features)
+    mtl_backward(tensors=[loss1, loss2], features=features)
     jac_to_grad(shared_module.parameters(), aggregator)
     optimizer.step()
     optimizer.zero_grad()
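
For reference, the base example the hunk header describes (input vectors of dimension 10, one scalar label per task) can be fleshed out into a sketch like the following. Everything outside the lines shown in the hunk is a placeholder reconstruction, and the import paths are assumed from this commit's torchjd.autojac prefix.

import torch

from torchjd.autojac import jac_to_grad, mtl_backward  # paths assumed
from torchjd.aggregation import UPGrad

shared_module = torch.nn.Linear(10, 5)
head1, head2 = torch.nn.Linear(5, 1), torch.nn.Linear(5, 1)
params = [*shared_module.parameters(), *head1.parameters(), *head2.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = torch.nn.MSELoss()
aggregator = UPGrad()

for _ in range(8):  # a few steps on random data
    inputs = torch.randn(16, 10)  # vectors of dimension 10
    target1, target2 = torch.randn(16, 1), torch.randn(16, 1)  # scalar labels

    features = shared_module(inputs)
    loss1 = loss_fn(head1(features), target1)
    loss2 = loss_fn(head2(features), target2)

    mtl_backward(tensors=[loss1, loss2], features=features)  # renamed keyword
    jac_to_grad(shared_module.parameters(), aggregator)
    optimizer.step()
    optimizer.zero_grad()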

latest/docs/autojac/mtl_backward/index.html

Lines changed: 17 additions & 7 deletions

@@ -295,22 +295,32 @@
  mtl_backward
- torchjd.autojac.mtl_backward(losses, features, tasks_params=None, shared_params=None, retain_graph=False, parallel_chunk_size=None)
-     [source: https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_mtl_backward.py#L19-L108]
+ torchjd.autojac.mtl_backward(tensors, features, grad_tensors=None, tasks_params=None, shared_params=None, retain_graph=False, parallel_chunk_size=None)
+     [source: https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_mtl_backward.py#L25-L126]
  In the context of Multi-Task Learning (MTL), we often have a shared feature extractor
  followed by several task-specific heads. A loss can then be computed for each task.
- This function computes the gradient of each task-specific loss with respect to its
- task-specific parameters and accumulates it in their .grad fields. Then, it computes the
- Jacobian of all losses with respect to the shared parameters and accumulates it in their
- .jac fields.
+ This function computes the gradient of each task-specific tensor with respect to its
+ task-specific parameters and accumulates it in their .grad fields. It also computes the
+ Jacobian of all tensors with respect to the shared parameters and accumulates it in their
+ .jac fields. These Jacobians have one row per task.
+ If the tensors are non-scalar, mtl_backward requires some initial gradients in
+ grad_tensors. This allows composing mtl_backward with some other function computing the
+ gradients with respect to the tensors (chain rule).
  Parameters:
-   losses (Sequence[Tensor]) – The task losses. The Jacobians will have one row per loss.
+   tensors (Sequence[Tensor]) – The task-specific tensors. If these are scalar (e.g. the
+     losses produced by every task), no grad_tensors are needed. If these are non-scalar
+     tensors, providing some grad_tensors is necessary.
    features (Sequence[Tensor] | Tensor) – The last shared representation used for all
      tasks, as given by the feature extractor. Should be non-empty.
+   grad_tensors (Sequence[Tensor] | None) – The initial gradients to backpropagate,
+     analogous to the grad_tensors parameter of torch.autograd.backward. If any of the
+     tensors is non-scalar, grad_tensors must be provided, with the same length and shapes
+     as tensors. Otherwise, this parameter is not needed and will default to scalars of 1.
    tasks_params (Sequence[Iterable[Tensor]] | None) – The parameters of each task-specific
      head. Their requires_grad flags must be set to True. If not provided, the parameters
-     considered for each task will default to the leaf tensors that are in the computation
-     graph of its loss, but that were not used to compute the features.
+     considered for each task will default to the leaf tensors that are in the computation
+     graph of its tensor, but that were not used to compute the features.
    shared_params (Iterable[Tensor] | None) – The parameters of the shared feature
      extractor. Their requires_grad flags must be set to True. If not provided, defaults to
      the leaf tensors that are in the computation graph of the features.
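
The main behavioral addition in this commit is the grad_tensors parameter, which mirrors the one in torch.autograd.backward: when the per-task tensors are non-scalar, the caller supplies the upstream gradients, so mtl_backward can be chained after some outer computation. A sketch of that usage, with placeholder models and trivial upstream gradients of ones; the import paths are assumed from the torchjd.autojac prefix above.

import torch

from torchjd.autojac import jac_to_grad, mtl_backward  # paths assumed
from torchjd.aggregation import UPGrad

shared_module = torch.nn.Linear(10, 5)
head1, head2 = torch.nn.Linear(5, 3), torch.nn.Linear(5, 3)

inputs = torch.randn(16, 10)
features = shared_module(inputs)
out1 = head1(features)  # non-scalar task tensors, shape (16, 3)
out2 = head2(features)

# Upstream gradients, as if produced by an outer computation (chain rule).
# They must match tensors in length and shape; ones stand in for real values.
grads = [torch.ones_like(out1), torch.ones_like(out2)]

mtl_backward(tensors=[out1, out2], features=features, grad_tensors=grads)
jac_to_grad(shared_module.parameters(), UPGrad())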

latest/examples/amp/index.html

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ Automatic Mixed Precision (AMP)
     loss2 = loss_fn(output2, target2)

     scaled_losses = scaler.scale([loss1, loss2])
-    mtl_backward(losses=scaled_losses, features=features)
+    mtl_backward(tensors=scaled_losses, features=features)
     jac_to_grad(shared_module.parameters(), aggregator)
     scaler.step(optimizer)
     scaler.update()

latest/examples/lightning_integration/index.html

Lines changed: 1 addition & 1 deletion

@@ -329,7 +329,7 @@ PyTorch Lightning Integration
         loss2 = mse_loss(output2, target2)

         opt = self.optimizers()
-        mtl_backward(losses=[loss1, loss2], features=features)
+        mtl_backward(tensors=[loss1, loss2], features=features)
         jac_to_grad(self.feature_extractor.parameters(), UPGrad())
         opt.step()
         opt.zero_grad()

latest/examples/monitoring/index.html

Lines changed: 1 addition & 1 deletion

@@ -350,7 +350,7 @@ Monitoring aggregations
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-    mtl_backward(losses=[loss1, loss2], features=features)
+    mtl_backward(tensors=[loss1, loss2], features=features)
     jac_to_grad(shared_module.parameters(), aggregator)
     optimizer.step()
     optimizer.zero_grad()
