Skip to content

Commit e32812f

Browse files
committed
1 parent b2072dd commit e32812f

3 files changed

Lines changed: 7 additions & 7 deletions

File tree

latest/_sources/examples/monitoring.rst.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ they have a negative inner product).
1919
2020
import torch
2121
from torch.nn import Linear, MSELoss, ReLU, Sequential
22-
from torch.optim import SGD
2322
from torch.nn.functional import cosine_similarity
23+
from torch.optim import SGD
2424
2525
from torchjd import mtl_backward
2626
from torchjd.aggregation import UPGrad
@@ -29,7 +29,7 @@ they have a negative inner product).
2929
"""Prints the extracted weights."""
3030
print(f"Weights: {weights}")
3131
32-
def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None:
32+
def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
3333
"""Prints the cosine similarity between the aggregation and the average gradient."""
3434
matrix = inputs[0]
3535
gd_output = matrix.mean(dim=0)
@@ -50,7 +50,7 @@ they have a negative inner product).
5050
aggregator = UPGrad()
5151
5252
aggregator.weighting.register_forward_hook(print_weights)
53-
aggregator.register_forward_hook(print_similarity_with_gd)
53+
aggregator.register_forward_hook(print_gd_similarity)
5454
5555
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
5656
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task

latest/examples/monitoring/index.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ <h1>Monitoring aggregations<a class="headerlink" href="#monitoring-aggregations"
253253
they have a negative inner product).</p>
254254
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
255255
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">Linear</span><span class="p">,</span> <span class="n">MSELoss</span><span class="p">,</span> <span class="n">ReLU</span><span class="p">,</span> <span class="n">Sequential</span>
256-
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.optim</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span>
257256
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn.functional</span><span class="w"> </span><span class="kn">import</span> <span class="n">cosine_similarity</span>
257+
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.optim</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span>
258258

259259
<span class="kn">from</span><span class="w"> </span><span class="nn">torchjd</span><span class="w"> </span><span class="kn">import</span> <span class="n">mtl_backward</span>
260260
<span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">UPGrad</span>
@@ -263,7 +263,7 @@ <h1>Monitoring aggregations<a class="headerlink" href="#monitoring-aggregations"
263263
</span><span class="hll"><span class="w"> </span><span class="sd">&quot;&quot;&quot;Prints the extracted weights.&quot;&quot;&quot;</span>
264264
</span><span class="hll"> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Weights: </span><span class="si">{</span><span class="n">weights</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
265265
</span>
266-
<span class="hll"><span class="k">def</span><span class="w"> </span><span class="nf">print_similarity_with_gd</span><span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">aggregation</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
266+
<span class="hll"><span class="k">def</span><span class="w"> </span><span class="nf">print_gd_similarity</span><span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">aggregation</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
267267
</span><span class="hll"><span class="w"> </span><span class="sd">&quot;&quot;&quot;Prints the cosine similarity between the aggregation and the average gradient.&quot;&quot;&quot;</span>
268268
</span><span class="hll"> <span class="n">matrix</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
269269
</span><span class="hll"> <span class="n">gd_output</span> <span class="o">=</span> <span class="n">matrix</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
@@ -284,7 +284,7 @@ <h1>Monitoring aggregations<a class="headerlink" href="#monitoring-aggregations"
284284
<span class="n">aggregator</span> <span class="o">=</span> <span class="n">UPGrad</span><span class="p">()</span>
285285

286286
<span class="hll"><span class="n">aggregator</span><span class="o">.</span><span class="n">weighting</span><span class="o">.</span><span class="n">register_forward_hook</span><span class="p">(</span><span class="n">print_weights</span><span class="p">)</span>
287-
</span><span class="hll"><span class="n">aggregator</span><span class="o">.</span><span class="n">register_forward_hook</span><span class="p">(</span><span class="n">print_similarity_with_gd</span><span class="p">)</span>
287+
</span><span class="hll"><span class="n">aggregator</span><span class="o">.</span><span class="n">register_forward_hook</span><span class="p">(</span><span class="n">print_gd_similarity</span><span class="p">)</span>
288288
</span>
289289
<span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="c1"># 8 batches of 16 random input vectors of length 10</span>
290290
<span class="n">task1_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># 8 batches of 16 targets for the first task</span>

0 commit comments

Comments (0)