
Commit 542837f

1 parent d303b97 commit 542837f

2 files changed

Lines changed: 2 additions & 2 deletions


latest/docs/autojac/backward/index.html

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@
 <h1>backward<a class="headerlink" href="#backward" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.backward">
-<span class="sig-prename descclassname"><span class="pre">torchjd.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/_autojac/_backward.py#L20-L98"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.backward" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/_autojac/_backward.py#L11-L89"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.backward" title="Link to this definition"></a></dt>
 <dd><p>Computes the Jacobian of all values in <code class="docutils literal notranslate"><span class="pre">tensors</span></code> with respect to all <code class="docutils literal notranslate"><span class="pre">inputs</span></code>. Computes its
 aggregation by the provided <code class="docutils literal notranslate"><span class="pre">aggregator</span></code> and accumulates it in the <code class="docutils literal notranslate"><span class="pre">.grad</span></code> fields of the
 <code class="docutils literal notranslate"><span class="pre">inputs</span></code>.</p>
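The only change above is the source-link line range for the documented torchjd.backward function. For orientation, here is a minimal usage sketch matching the signature shown in the diff (tensors, aggregator, inputs=None, retain_graph=False, parallel_chunk_size=None); the UPGrad aggregator from torchjd.aggregation is an assumed choice for illustration and is not part of this diff:

    import torch

    from torchjd import backward
    from torchjd.aggregation import UPGrad  # assumed aggregator choice, not part of this diff

    # Two vector-valued objectives computed from the same parameter.
    param = torch.tensor([1.0, 2.0], requires_grad=True)
    y1 = param ** 2
    y2 = param * 3.0

    # The Jacobian of [y1, y2] with respect to param is computed, aggregated by
    # UPGrad, and the aggregated result is accumulated in param.grad.
    backward([y1, y2], aggregator=UPGrad())
    print(param.grad)

With inputs left as None, the .grad fields of the leaf tensors used to compute y1 and y2 (here, param) presumably receive the aggregated result, consistent with the docstring above.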

latest/docs/autojac/mtl_backward/index.html

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@
 <h1>mtl_backward<a class="headerlink" href="#mtl-backward" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.mtl_backward">
-<span class="sig-prename descclassname"><span class="pre">torchjd.</span></span><span class="sig-name descname"><span class="pre">mtl_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">losses</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shared_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/_autojac/_mtl_backward.py#L23-L117"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.mtl_backward" title="Link to this definition"></a></dt>
+<span class="sig-prename descclassname"><span class="pre">torchjd.</span></span><span class="sig-name descname"><span class="pre">mtl_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">losses</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggregator</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shared_params</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/_autojac/_mtl_backward.py#L11-L105"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.mtl_backward" title="Link to this definition"></a></dt>
 <dd><p>In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed
 by several task-specific heads. A loss can then be computed for each task.</p>
 <p>This function computes the gradient of each task-specific loss with respect to its task-specific
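Again, only the source-link line range changes. As a companion to the docstring quoted above, here is a minimal sketch of how the documented mtl_backward signature is typically called, with a toy shared extractor and two task heads; UPGrad is again an assumed aggregator choice, and leaving tasks_params and shared_params at their None defaults is an assumption made for illustration:

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad  # assumed aggregator choice, not part of this diff

    shared = Sequential(Linear(10, 5), ReLU())  # shared feature extractor
    head1 = Linear(5, 1)                        # task-specific heads
    head2 = Linear(5, 1)
    loss_fn = MSELoss()

    x = torch.randn(8, 10)
    target1 = torch.randn(8, 1)
    target2 = torch.randn(8, 1)

    features = shared(x)
    loss1 = loss_fn(head1(features), target1)
    loss2 = loss_fn(head2(features), target2)

    # Each task loss is differentiated with respect to its own head's parameters,
    # the Jacobian of the losses with respect to the shared parameters is
    # aggregated by UPGrad, and the results are accumulated in the .grad fields.
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())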
