|
295 | 295 | <h1>mtl_backward<a class="headerlink" href="#mtl-backward" title="Link to this heading">¶</a></h1> |
296 | 296 | <dl class="py function"> |
297 | 297 | <dt class="sig sig-object py" id="torchjd.autojac.mtl_backward"> |
298 | | -<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">mtl_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">losses</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shared_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_mtl_backward.py#L19-L108"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.mtl_backward" title="Link to this definition">¶</a></dt> |
| 298 | +<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">mtl_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shared_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_mtl_backward.py#L25-L126"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.mtl_backward" title="Link to this definition">¶</a></dt> |
299 | 299 | <dd><p>In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed |
300 | 300 | by several task-specific heads. A loss can then be computed for each task.</p> |
301 | | -<p>This function computes the gradient of each task-specific loss with respect to its task-specific |
302 | | -parameters and accumulates it in their <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields. Then, it computes the Jacobian of all |
303 | | -losses with respect to the shared parameters and accumulates it in their <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields.</p> |
| 301 | +<p>This function computes the gradient of each task-specific tensor with respect to its |
| 302 | +task-specific parameters and accumulates it in their <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields. It also computes the |
| 303 | +Jacobian of all tensors with respect to the shared parameters and accumulates it in their |
| 304 | +<code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields. These Jacobians have one row per task.</p> |
| 305 | +<p>If the <code class="docutils literal notranslate"><span class="pre">tensors</span></code> are non-scalar, <code class="docutils literal notranslate"><span class="pre">mtl_backward</span></code> requires some initial gradients in |
| 306 | +<code class="docutils literal notranslate"><span class="pre">grad_tensors</span></code>. This allows composing <code class="docutils literal notranslate"><span class="pre">mtl_backward</span></code> with another function that
| 307 | +computes the gradients with respect to the tensors (chain rule).</p>
304 | 308 | <dl class="field-list simple"> |
305 | 309 | <dt class="field-odd">Parameters<span class="colon">:</span></dt> |
306 | 310 | <dd class="field-odd"><ul class="simple"> |
307 | | -<li><p><strong>losses</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>]</span>) – The task losses. The Jacobians will have one row per loss.</p></li> |
| 311 | +<li><p><strong>tensors</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>]</span>) – The task-specific tensors. If these are scalar (e.g. the loss produced by
| 312 | +each task), no <code class="docutils literal notranslate"><span class="pre">grad_tensors</span></code> are needed. If these are non-scalar tensors, providing
| 313 | +<code class="docutils literal notranslate"><span class="pre">grad_tensors</span></code> is necessary.</p></li>
308 | 314 | <li><p><strong>features</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a></span>) – The last shared representation used for all tasks, as given by the feature |
309 | 315 | extractor. Should be non-empty.</p></li> |
| 316 | +<li><p><strong>grad_tensors</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The initial gradients to backpropagate, analogous to the <code class="docutils literal notranslate"><span class="pre">grad_tensors</span></code>
| 317 | +parameter of <code class="docutils literal notranslate"><span class="pre">torch.autograd.backward</span></code>. If any of the <code class="docutils literal notranslate"><span class="pre">tensors</span></code> is non-scalar, |
| 318 | +<code class="docutils literal notranslate"><span class="pre">grad_tensors</span></code> must be provided, with the same length and shapes as <code class="docutils literal notranslate"><span class="pre">tensors</span></code>. |
| 319 | +Otherwise, this parameter is not needed and defaults to scalar tensors equal to 1.</p></li>
310 | 320 | <li><p><strong>tasks_params</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>]] | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The parameters of each task-specific head. Their <code class="docutils literal notranslate"><span class="pre">requires_grad</span></code> flags |
311 | 321 | must be set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. If not provided, the parameters considered for each task will |
312 | | -default to the leaf tensors that are in the computation graph of its loss, but that were not |
313 | | -used to compute the <code class="docutils literal notranslate"><span class="pre">features</span></code>.</p></li> |
| 322 | +default to the leaf tensors that are in the computation graph of its tensor, but that were |
| 323 | +not used to compute the <code class="docutils literal notranslate"><span class="pre">features</span></code>.</p></li> |
314 | 324 | <li><p><strong>shared_params</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The parameters of the shared feature extractor. Their <code class="docutils literal notranslate"><span class="pre">requires_grad</span></code> |
315 | 325 | flags must be set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. If not provided, defaults to the leaf tensors that are in the |
316 | 326 | computation graph of the <code class="docutils literal notranslate"><span class="pre">features</span></code>.</p></li> |
|
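The behavior described above can be sketched with plain PyTorch autograd: each task loss's gradient is accumulated into its own head's parameters, while the shared parameters receive one Jacobian row per task. This is an illustrative reimplementation under assumed model shapes, not the torchjd implementation itself; the module names (`shared`, `head1`, `head2`) are hypothetical.

```python
# Sketch of what mtl_backward computes, using plain PyTorch autograd.
# Model names and shapes are illustrative, not part of torchjd.
import torch

torch.manual_seed(0)

# Shared feature extractor followed by two task-specific heads.
shared = torch.nn.Linear(4, 3)
head1 = torch.nn.Linear(3, 1)
head2 = torch.nn.Linear(3, 1)

x = torch.randn(8, 4)
features = shared(x)
loss1 = head1(features).square().mean()
loss2 = head2(features).square().mean()

# Gradient of each task loss w.r.t. its own head's parameters,
# stored in .grad (as mtl_backward does for tasks_params).
for loss, head in [(loss1, head1), (loss2, head2)]:
    grads = torch.autograd.grad(loss, list(head.parameters()), retain_graph=True)
    for p, g in zip(head.parameters(), grads):
        p.grad = g

# Jacobian of all task losses w.r.t. the shared parameters: one row per task.
# For non-scalar tensors, the initial gradients (grad_tensors) would be
# passed here via the grad_outputs argument of torch.autograd.grad.
shared_params = list(shared.parameters())
rows = []
for loss in (loss1, loss2):
    gs = torch.autograd.grad(loss, shared_params, retain_graph=True)
    rows.append(torch.cat([g.flatten() for g in gs]))
jacobian = torch.stack(rows)  # shape: (num_tasks, num_shared_params)
print(jacobian.shape)  # 2 tasks, 3*4 weights + 3 biases = 15 shared params
```

Since the scalar losses here need no initial gradients, no `grad_tensors` equivalent appears; the non-scalar case would supply them through `grad_outputs`.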