
Commit d949171

1 parent 4ac5cb3 commit d949171

8 files changed

Lines changed: 12 additions & 11 deletions


latest/_sources/examples/iwmtl.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ The following example shows how to do that.
 optimizer = SGD(params, lr=0.1)
 mse = MSELoss(reduction="none")
 weighting = Flattening(UPGradWeighting())
-engine = Engine(shared_module.modules(), batch_dim=0)
+engine = Engine(shared_module, batch_dim=0)
 
 inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
 task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task

latest/_sources/examples/iwrm.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 params = model.parameters()
 optimizer = SGD(params, lr=0.1)
 weighting = UPGradWeighting()
-engine = Engine(model.modules(), batch_dim=0)
+engine = Engine(model, batch_dim=0)
 
 for x, y in zip(X, Y):
     y_hat = model(x).squeeze(dim=1)  # shape: [16]

latest/_sources/examples/partial_jd.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
 
 # Create the autogram engine that will compute the Gramian of the
 # Jacobian with respect to the two last Linear layers' parameters.
-engine = Engine(model[2:].modules(), batch_dim=0)
+engine = Engine(model[2:], batch_dim=0)
 
 params = model.parameters()
 optimizer = SGD(params, lr=0.1)

latest/docs/autogram/engine/index.html

Lines changed: 5 additions & 4 deletions
@@ -251,7 +251,7 @@
 <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading"></a></h1>
 <dl class="py class">
 <dt class="sig sig-object py" id="torchjd.autogram.Engine">
-<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L41-L318"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition"></a></dt>
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.autogram.</span></span><span class="sig-name descname"><span class="pre">Engine</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">modules</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/autogram/_engine.py#L40-L318"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autogram.Engine" title="Link to this definition"></a></dt>
 <dd><p>Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct
 parameters of all provided modules. It is based on Algorithm 3 of <a class="reference external" href="https://arxiv.org/pdf/2406.16232">Jacobian Descent For
 Multi-Objective Optimization</a> but goes even further:</p>

@@ -270,8 +270,9 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
 <dd class="field-odd"><ul class="simple">
-<li><p><strong>modules</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module" title="(in PyTorch v2.8)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a>]</span>) – A collection of modules whose direct (non-recursive) parameters will contribute
-to the Gramian of the Jacobian.</p></li>
+<li><p><strong>modules</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module" title="(in PyTorch v2.8)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></a></span>) – The modules whose parameters will contribute to the Gramian of the Jacobian.
+Several modules can be provided, but it’s important that none of them is a child module of
+another of them.</p></li>
 <li><p><strong>batch_dim</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – If the modules work with batches and process each batch element independently,
 then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
 memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates

@@ -300,7 +301,7 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
 </span><span class="hll"><span class="n">weighting</span> <span class="o">=</span> <span class="n">UPGradWeighting</span><span class="p">()</span>
 </span>
 <span class="hll"><span class="c1"># Create the engine before the backward pass, and only once.</span>
-</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">(),</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 </span>
 <span class="k">for</span> <span class="nb">input</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
 <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># shape: [16]</span>
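The updated docstring says several modules may be passed, but that none of them may be a child module of another. The sketch below illustrates that constraint with a hypothetical `check_no_nesting` helper and a minimal stand-in `Module` class; neither is part of torchjd or PyTorch, they only mirror the recursive `modules()` traversal the rule is phrased in terms of.

```python
class Module:
    """Minimal stand-in for torch.nn.Module: tracks children and yields
    itself plus all descendants from modules(), like the real API."""

    def __init__(self, *children):
        self._children = list(children)

    def modules(self):
        yield self
        for child in self._children:
            yield from child.modules()


def check_no_nesting(*modules):
    """Hypothetical check mirroring the Engine requirement: no provided
    module may be a (recursive) child of another provided module."""
    for a in modules:
        for b in modules:
            if a is not b and any(m is b for m in a.modules()):
                raise ValueError("one provided module is a child of another")


inner = Module()
outer = Module(inner)

check_no_nesting(outer)            # fine: a single root module
check_no_nesting(inner, Module())  # fine: disjoint modules
try:
    check_no_nesting(outer, inner)  # invalid: inner is a child of outer
except ValueError as e:
    print(e)
```

This is why the old call style `Engine(model.modules(), batch_dim=0)` no longer makes sense under the new signature: `model.modules()` yields the model *and* all of its descendants, which would violate the no-nesting rule, whereas passing `model` itself covers all of its parameters.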

latest/examples/iwmtl/index.html

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ <h1>Instance-Wise Multi-Task Learning (IWMTL)<a class="headerlink" href="#instan
 <span class="n">optimizer</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
 <span class="hll"><span class="n">mse</span> <span class="o">=</span> <span class="n">MSELoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
 </span><span class="hll"><span class="n">weighting</span> <span class="o">=</span> <span class="n">Flattening</span><span class="p">(</span><span class="n">UPGradWeighting</span><span class="p">())</span>
-</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">shared_module</span><span class="o">.</span><span class="n">modules</span><span class="p">(),</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">shared_module</span><span class="p">,</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 </span>
 <span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="c1"># 8 batches of 16 random input vectors of length 10</span>
 <span class="n">task1_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">)</span> <span class="c1"># 8 batches of 16 targets for the first task</span>

latest/examples/iwrm/index.html

Lines changed: 1 addition & 1 deletion
@@ -369,7 +369,7 @@ <h1>Instance-Wise Risk Minimization (IWRM)<a class="headerlink" href="#instance-
 <span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
 <span class="n">optimizer</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
 <span class="hll"><span class="n">weighting</span> <span class="o">=</span> <span class="n">UPGradWeighting</span><span class="p">()</span>
-</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">(),</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 </span>
 <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
 <span class="n">y_hat</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># shape: [16]</span>

latest/examples/partial_jd/index.html

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ <h1>Partial Jacobian Descent for IWRM<a class="headerlink" href="#partial-jacobi
 
 <span class="hll"><span class="c1"># Create the autogram engine that will compute the Gramian of the</span>
 </span><span class="hll"><span class="c1"># Jacobian with respect to the two last Linear layers&#39; parameters.</span>
-</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">model</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span><span class="o">.</span><span class="n">modules</span><span class="p">(),</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+</span><span class="hll"><span class="n">engine</span> <span class="o">=</span> <span class="n">Engine</span><span class="p">(</span><span class="n">model</span><span class="p">[</span><span class="mi">2</span><span class="p">:],</span> <span class="n">batch_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 </span>
 <span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
 <span class="n">optimizer</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>

latest/searchindex.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
