Skip to content

Commit be7bdad

Browse files
fix(autojac): Fix differentiation with repeated tensors (#277)
* Add tests with repeated tensors for backward * Add tests with repeated tensors for mtl_backward * Remove check of uniqueness of the elements provided to ordered_set * Remove casting outputs to ordered_set in _Differentiate * Add changelog entry --------- Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
1 parent 99e8bea commit be7bdad

File tree

5 files changed

+170
-11
lines changed

5 files changed

+170
-11
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ changes that do not affect the user.
1818
project onto the dual cone. This may minimally affect the output of these aggregators.
1919

2020
### Fixed
21+
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they
22+
appear several times in a list of tensors provided as an argument). Instead of raising an exception
23+
in these cases, we are now aligned with the behavior of `torch.autograd.backward`. Repeated
24+
tensors that we differentiate lead to repeated rows in the Jacobian, prior to aggregation, and
25+
repeated tensors with respect to which we differentiate count only once.
2126
- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
2227
practice, this fix should only affect some matrices with extremely large values, which should
2328
not usually happen.

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
retain_graph: bool,
1616
create_graph: bool,
1717
):
18-
self.outputs = ordered_set(outputs)
18+
self.outputs = list(outputs)
1919
self.inputs = ordered_set(inputs)
2020
self.retain_graph = retain_graph
2121
self.create_graph = create_graph

src/torchjd/autojac/_transform/_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,7 @@
1616

1717

1818
def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
19-
elements = list(elements)
20-
result = OrderedDict.fromkeys(elements, None)
21-
if len(elements) != len(result):
22-
raise ValueError(
23-
f"Parameter `elements` should contain unique elements. Found `elements = {elements}`."
24-
)
25-
26-
return result
19+
return OrderedDict.fromkeys(elements, None)
2720

2821

2922
def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:

tests/unit/autojac/test_backward.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
from pytest import mark, raises
3+
from torch.autograd import grad
34
from torch.testing import assert_close
45

56
from torchjd import backward
6-
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad
7+
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
78

89

910
@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
@@ -214,3 +215,47 @@ def test_tensor_used_multiple_times(chunk_size: int | None):
214215
)
215216

216217
assert_close(a.grad, aggregator(expected_jacobian).squeeze())
218+
219+
220+
def test_repeated_tensors():
221+
"""
222+
Tests that backward correctly works when some tensors are repeated. In this case, since
223+
torch.autograd.backward would sum the gradients of the repeated tensors, it is natural for
224+
autojac to compute a Jacobian with one row per repeated tensor, and to aggregate it.
225+
"""
226+
227+
a1 = torch.tensor([1.0, 2.0], requires_grad=True)
228+
a2 = torch.tensor([3.0, 4.0], requires_grad=True)
229+
230+
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
231+
y2 = (a1**2).sum() + (a2**2).sum()
232+
233+
expected_grad_wrt_a1 = grad([y1, y1, y2], a1, retain_graph=True)[0]
234+
expected_grad_wrt_a2 = grad([y1, y1, y2], a2, retain_graph=True)[0]
235+
236+
backward([y1, y1, y2], Sum())
237+
238+
assert_close(a1.grad, expected_grad_wrt_a1)
239+
assert_close(a2.grad, expected_grad_wrt_a2)
240+
241+
242+
def test_repeated_inputs():
243+
"""
244+
Tests that backward correctly works when some inputs are repeated. In this case, since
245+
torch.autograd.backward ignores the repetition of the inputs, it is natural for autojac to
246+
ignore that as well.
247+
"""
248+
249+
a1 = torch.tensor([1.0, 2.0], requires_grad=True)
250+
a2 = torch.tensor([3.0, 4.0], requires_grad=True)
251+
252+
y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
253+
y2 = (a1**2).sum() + (a2**2).sum()
254+
255+
expected_grad_wrt_a1 = grad([y1, y2], a1, retain_graph=True)[0]
256+
expected_grad_wrt_a2 = grad([y1, y2], a2, retain_graph=True)[0]
257+
258+
backward([y1, y2], Sum(), inputs=[a1, a1, a2])
259+
260+
assert_close(a1.grad, expected_grad_wrt_a1)
261+
assert_close(a2.grad, expected_grad_wrt_a2)

tests/unit/autojac/test_mtl_backward.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
from pytest import mark, raises
3+
from torch.autograd import grad
34
from torch.testing import assert_close
45

56
from torchjd import mtl_backward
6-
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad
7+
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
78

89

910
@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
@@ -557,3 +558,118 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails():
557558
features=[f],
558559
aggregator=UPGrad(),
559560
)
561+
562+
563+
def test_repeated_losses():
564+
"""
565+
Tests that mtl_backward correctly works when some losses are repeated. In this case, since
566+
torch.autograd.backward would sum the gradients of the repeated losses, it is natural for
567+
autojac to sum the task-specific gradients, and to compute and aggregate a Jacobian with one row
568+
per repeated tensor, for shared gradients.
569+
"""
570+
571+
p0 = torch.tensor([1.0, 2.0], requires_grad=True)
572+
p1 = torch.tensor([1.0, 2.0], requires_grad=True)
573+
p2 = torch.tensor([3.0, 4.0], requires_grad=True)
574+
575+
f1 = torch.tensor([-1.0, 1.0]) @ p0
576+
f2 = (p0**2).sum() + p0.norm()
577+
y1 = f1 * p1[0] + f2 * p1[1]
578+
y2 = f1 * p2[0] + f2 * p2[1]
579+
580+
expected_grad_wrt_p0 = grad([y1, y1, y2], [p0], retain_graph=True)[0]
581+
expected_grad_wrt_p1 = grad([y1, y1], [p1], retain_graph=True)[0]
582+
expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0]
583+
584+
losses = [y1, y1, y2]
585+
mtl_backward(losses=losses, features=[f1, f2], aggregator=Sum(), retain_graph=True)
586+
587+
assert_close(p0.grad, expected_grad_wrt_p0)
588+
assert_close(p1.grad, expected_grad_wrt_p1)
589+
assert_close(p2.grad, expected_grad_wrt_p2)
590+
591+
592+
def test_repeated_features():
593+
"""
594+
Tests that mtl_backward correctly works when some features are repeated. Repeated features are
595+
a bit more tricky, because we differentiate with respect to them (in which case it shouldn't
596+
matter that they are repeated) and we also differentiate them (in which case it should lead to
597+
extra rows in the Jacobian).
598+
"""
599+
600+
p0 = torch.tensor([1.0, 2.0], requires_grad=True)
601+
p1 = torch.tensor([1.0, 2.0], requires_grad=True)
602+
p2 = torch.tensor([3.0, 4.0], requires_grad=True)
603+
604+
f1 = torch.tensor([-1.0, 1.0]) @ p0
605+
f2 = (p0**2).sum() + p0.norm()
606+
y1 = f1 * p1[0] + f2 * p1[1]
607+
y2 = f1 * p2[0] + f2 * p2[1]
608+
609+
grad_outputs = grad([y1, y2], [f1, f1, f2], retain_graph=True)
610+
expected_grad_wrt_p0 = grad([f1, f1, f2], [p0], grad_outputs, retain_graph=True)[0]
611+
expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0]
612+
expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0]
613+
614+
features = [f1, f1, f2]
615+
mtl_backward(losses=[y1, y2], features=features, aggregator=Sum())
616+
617+
assert_close(p0.grad, expected_grad_wrt_p0)
618+
assert_close(p1.grad, expected_grad_wrt_p1)
619+
assert_close(p2.grad, expected_grad_wrt_p2)
620+
621+
622+
def test_repeated_shared_params():
623+
"""
624+
Tests that mtl_backward correctly works when some shared params are repeated. Since these are tensors
625+
with respect to which we differentiate, to match the behavior of torch.autograd.backward, this
626+
repetition should not affect the result.
627+
"""
628+
629+
p0 = torch.tensor([1.0, 2.0], requires_grad=True)
630+
p1 = torch.tensor([1.0, 2.0], requires_grad=True)
631+
p2 = torch.tensor([3.0, 4.0], requires_grad=True)
632+
633+
f1 = torch.tensor([-1.0, 1.0]) @ p0
634+
f2 = (p0**2).sum() + p0.norm()
635+
y1 = f1 * p1[0] + f2 * p1[1]
636+
y2 = f1 * p2[0] + f2 * p2[1]
637+
638+
expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0]
639+
expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0]
640+
expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0]
641+
642+
shared_params = [p0, p0]
643+
mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), shared_params=shared_params)
644+
645+
assert_close(p0.grad, expected_grad_wrt_p0)
646+
assert_close(p1.grad, expected_grad_wrt_p1)
647+
assert_close(p2.grad, expected_grad_wrt_p2)
648+
649+
650+
def test_repeated_task_params():
651+
"""
652+
Tests that mtl_backward correctly works when some task-specific params are repeated for some
653+
task. Since these are tensors with respect to which we differentiate, to match the behavior of
654+
torch.autograd.backward, this repetition should not affect the result.
655+
"""
656+
657+
p0 = torch.tensor([1.0, 2.0], requires_grad=True)
658+
p1 = torch.tensor([1.0, 2.0], requires_grad=True)
659+
p2 = torch.tensor([3.0, 4.0], requires_grad=True)
660+
661+
f1 = torch.tensor([-1.0, 1.0]) @ p0
662+
f2 = (p0**2).sum() + p0.norm()
663+
y1 = f1 * p1[0] + f2 * p1[1]
664+
y2 = f1 * p2[0] + f2 * p2[1]
665+
666+
expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0]
667+
expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0]
668+
expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0]
669+
670+
tasks_params = [[p1, p1], [p2]]
671+
mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), tasks_params=tasks_params)
672+
673+
assert_close(p0.grad, expected_grad_wrt_p0)
674+
assert_close(p1.grad, expected_grad_wrt_p1)
675+
assert_close(p2.grad, expected_grad_wrt_p2)

0 commit comments

Comments
 (0)