diff --git a/CHANGELOG.md b/CHANGELOG.md index faa29afc5..138b21e49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,11 @@ changes that do not affect the user. project onto the dual cone. This may minimally affect the output of these aggregators. ### Fixed +- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they + appear several times in a list of tensors provided as argument). Instead of raising an exception + in these cases, we are now aligned with the behavior of `torch.autograd.backward`. Repeated + tensors that we differentiate lead to repeated rows in the Jacobian, prior to aggregation, and + repeated tensors with respect to which we differentiate count only once. - Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In practice, this fix should only affect some matrices with extremely large values, which should not usually happen. diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index ec30aa77a..e5eedd19b 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -15,7 +15,7 @@ def __init__( retain_graph: bool, create_graph: bool, ): - self.outputs = ordered_set(outputs) + self.outputs = list(outputs) self.inputs = ordered_set(inputs) self.retain_graph = retain_graph self.create_graph = create_graph diff --git a/src/torchjd/autojac/_transform/_utils.py b/src/torchjd/autojac/_transform/_utils.py index 47639ac3c..3339ca997 100644 --- a/src/torchjd/autojac/_transform/_utils.py +++ b/src/torchjd/autojac/_transform/_utils.py @@ -16,14 +16,7 @@ def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]: - elements = list(elements) - result = OrderedDict.fromkeys(elements, None) - if len(elements) != len(result): - raise ValueError( - f"Parameter `elements` should contain unique elements. Found `elements = {elements}`." 
- ) - - return result + return OrderedDict.fromkeys(elements, None) def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]: diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index ab4075fb9..f8f16c8c2 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -1,9 +1,10 @@ import torch from pytest import mark, raises +from torch.autograd import grad from torch.testing import assert_close from torchjd import backward -from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad +from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad @mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) @@ -214,3 +215,47 @@ def test_tensor_used_multiple_times(chunk_size: int | None): ) assert_close(a.grad, aggregator(expected_jacobian).squeeze()) + + +def test_repeated_tensors(): + """ + Tests that backward correctly works when some tensors are repeated. In this case, since + torch.autograd.backward would sum the gradients of the repeated tensors, it is natural for + autojac to compute a Jacobian with one row per repeated tensor, and to aggregate it. + """ + + a1 = torch.tensor([1.0, 2.0], requires_grad=True) + a2 = torch.tensor([3.0, 4.0], requires_grad=True) + + y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() + y2 = (a1**2).sum() + (a2**2).sum() + + expected_grad_wrt_a1 = grad([y1, y1, y2], a1, retain_graph=True)[0] + expected_grad_wrt_a2 = grad([y1, y1, y2], a2, retain_graph=True)[0] + + backward([y1, y1, y2], Sum()) + + assert_close(a1.grad, expected_grad_wrt_a1) + assert_close(a2.grad, expected_grad_wrt_a2) + + +def test_repeated_inputs(): + """ + Tests that backward correctly works when some inputs are repeated. In this case, since + torch.autograd.backward ignores the repetition of the inputs, it is natural for autojac to + ignore that as well. 
+ """ + + a1 = torch.tensor([1.0, 2.0], requires_grad=True) + a2 = torch.tensor([3.0, 4.0], requires_grad=True) + + y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() + y2 = (a1**2).sum() + (a2**2).sum() + + expected_grad_wrt_a1 = grad([y1, y2], a1, retain_graph=True)[0] + expected_grad_wrt_a2 = grad([y1, y2], a2, retain_graph=True)[0] + + backward([y1, y2], Sum(), inputs=[a1, a1, a2]) + + assert_close(a1.grad, expected_grad_wrt_a1) + assert_close(a2.grad, expected_grad_wrt_a2) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 21dcbca07..6f98a4ac4 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -1,9 +1,10 @@ import torch from pytest import mark, raises +from torch.autograd import grad from torch.testing import assert_close from torchjd import mtl_backward -from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad +from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad @mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) @@ -557,3 +558,118 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): features=[f], aggregator=UPGrad(), ) + + +def test_repeated_losses(): + """ + Tests that mtl_backward correctly works when some losses are repeated. In this case, since + torch.autograd.backward would sum the gradients of the repeated losses, it is natural for + autojac to sum the task-specific gradients, and to compute and aggregate a Jacobian with one row + per repeated tensor, for shared gradients. 
+ """ + + p0 = torch.tensor([1.0, 2.0], requires_grad=True) + p1 = torch.tensor([1.0, 2.0], requires_grad=True) + p2 = torch.tensor([3.0, 4.0], requires_grad=True) + + f1 = torch.tensor([-1.0, 1.0]) @ p0 + f2 = (p0**2).sum() + p0.norm() + y1 = f1 * p1[0] + f2 * p1[1] + y2 = f1 * p2[0] + f2 * p2[1] + + expected_grad_wrt_p0 = grad([y1, y1, y2], [p0], retain_graph=True)[0] + expected_grad_wrt_p1 = grad([y1, y1], [p1], retain_graph=True)[0] + expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + + losses = [y1, y1, y2] + mtl_backward(losses=losses, features=[f1, f2], aggregator=Sum(), retain_graph=True) + + assert_close(p0.grad, expected_grad_wrt_p0) + assert_close(p1.grad, expected_grad_wrt_p1) + assert_close(p2.grad, expected_grad_wrt_p2) + + +def test_repeated_features(): + """ + Tests that mtl_backward correctly works when some features are repeated. Repeated features are + a bit more tricky, because we differentiate with respect to them (in which case it shouldn't + matter that they are repeated) and we also differentiate them (in which case it should lead to + extra rows in the Jacobian). 
+ """ + + p0 = torch.tensor([1.0, 2.0], requires_grad=True) + p1 = torch.tensor([1.0, 2.0], requires_grad=True) + p2 = torch.tensor([3.0, 4.0], requires_grad=True) + + f1 = torch.tensor([-1.0, 1.0]) @ p0 + f2 = (p0**2).sum() + p0.norm() + y1 = f1 * p1[0] + f2 * p1[1] + y2 = f1 * p2[0] + f2 * p2[1] + + grad_outputs = grad([y1, y2], [f1, f1, f2], retain_graph=True) + expected_grad_wrt_p0 = grad([f1, f1, f2], [p0], grad_outputs, retain_graph=True)[0] + expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] + expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + + features = [f1, f1, f2] + mtl_backward(losses=[y1, y2], features=features, aggregator=Sum()) + + assert_close(p0.grad, expected_grad_wrt_p0) + assert_close(p1.grad, expected_grad_wrt_p1) + assert_close(p2.grad, expected_grad_wrt_p2) + + +def test_repeated_shared_params(): + """ + Tests that mtl_backward correctly works when some shared params are repeated. Since these are tensors + with respect to which we differentiate, to match the behavior of torch.autograd.backward, this + repetition should not affect the result. 
+ """ + + p0 = torch.tensor([1.0, 2.0], requires_grad=True) + p1 = torch.tensor([1.0, 2.0], requires_grad=True) + p2 = torch.tensor([3.0, 4.0], requires_grad=True) + + f1 = torch.tensor([-1.0, 1.0]) @ p0 + f2 = (p0**2).sum() + p0.norm() + y1 = f1 * p1[0] + f2 * p1[1] + y2 = f1 * p2[0] + f2 * p2[1] + + expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0] + expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] + expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + + shared_params = [p0, p0] + mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), shared_params=shared_params) + + assert_close(p0.grad, expected_grad_wrt_p0) + assert_close(p1.grad, expected_grad_wrt_p1) + assert_close(p2.grad, expected_grad_wrt_p2) + + +def test_repeated_task_params(): + """ + Tests that mtl_backward correctly works when some task-specific params are repeated for some + task. Since these are tensors with respect to which we differentiate, to match the behavior of + torch.autograd.backward, this repetition should not affect the result. + """ + + p0 = torch.tensor([1.0, 2.0], requires_grad=True) + p1 = torch.tensor([1.0, 2.0], requires_grad=True) + p2 = torch.tensor([3.0, 4.0], requires_grad=True) + + f1 = torch.tensor([-1.0, 1.0]) @ p0 + f2 = (p0**2).sum() + p0.norm() + y1 = f1 * p1[0] + f2 * p1[1] + y2 = f1 * p2[0] + f2 * p2[1] + + expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0] + expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] + expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + + tasks_params = [[p1, p1], [p2]] + mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), tasks_params=tasks_params) + + assert_close(p0.grad, expected_grad_wrt_p0) + assert_close(p1.grad, expected_grad_wrt_p1) + assert_close(p2.grad, expected_grad_wrt_p2)