diff --git a/CHANGELOG.md b/CHANGELOG.md index 138b21e49..933db572d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ changes that do not affect the user. - Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to project onto the dual cone. This may minimally affect the output of these aggregators. +- Refactored internal verifications in the autojac engine so that they do not run at runtime + anymore. This should minimally improve the performance and reduce the memory usage of `backward` + and `mtl_backward`. ### Fixed - Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index e5eedd19b..76b61eb46 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -20,7 +20,7 @@ def __init__( self.retain_graph = retain_graph self.create_graph = create_graph - def _compute(self, tensors: _A) -> _A: + def __call__(self, tensors: _A) -> _A: tensor_outputs = [tensors[output] for output in self.outputs] differentiated_tuple = self._differentiate(tensor_outputs) @@ -38,12 +38,6 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...] tensor_outputs should be. """ - @property - def required_keys(self) -> set[Tensor]: + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: # outputs in the forward direction become inputs in the backward direction, and vice-versa - return set(self.outputs) - - @property - def output_keys(self) -> set[Tensor]: - # outputs in the forward direction become inputs in the backward direction, and vice-versa - return set(self.inputs) + return set(self.outputs), set(self.inputs) diff --git a/src/torchjd/autojac/_transform/accumulate.py b/src/torchjd/autojac/_transform/accumulate.py index e9b372ed4..f509a5dfb 100644 --- a/src/torchjd/autojac/_transform/accumulate.py +++ b/src/torchjd/autojac/_transform/accumulate.py @@ -10,7 +10,7 @@ class Accumulate(Transform[Gradients, EmptyTensorDict]): def __init__(self, required_keys: Iterable[Tensor]): self._required_keys = set(required_keys) - def _compute(self, gradients: Gradients) -> EmptyTensorDict: + def __call__(self, gradients: Gradients) -> EmptyTensorDict: """ Accumulates gradients with respect to keys in their ``.grad`` field. """ @@ -28,13 +28,8 @@ def _compute(self, gradients: Gradients) -> EmptyTensorDict: return EmptyTensorDict() - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys - - @property - def output_keys(self) -> set[Tensor]: - return set() + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return self._required_keys, set() def _check_expects_grad(tensor: Tensor) -> None: diff --git a/src/torchjd/autojac/_transform/aggregate.py b/src/torchjd/autojac/_transform/aggregate.py index 250f688e9..e57eef1b9 100644 --- a/src/torchjd/autojac/_transform/aggregate.py +++ b/src/torchjd/autojac/_transform/aggregate.py @@ -23,16 +23,11 @@ def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]): self._aggregator_str = str(aggregator) self.transform = reshape << aggregate_matrices << matrixify - def _compute(self, input: Jacobians) -> Gradients: + def __call__(self, input: Jacobians) -> Gradients: return self.transform(input) - @property - def required_keys(self) -> set[Tensor]: - return self.transform.required_keys - - @property - def output_keys(self) -> set[Tensor]: - return self.transform.output_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return self.transform.check_and_get_keys() class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]): @@ -40,7 +35,7 @@ def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]): self.key_order = ordered_set(key_order) self.aggregator = aggregator - def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors: + def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors: """ Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to @@ -53,13 +48,9 @@ def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors: ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order) return self._aggregate_group(ordered_matrices, self.aggregator) - @property - def required_keys(self) -> set[Tensor]: - return set(self.key_order) - - @property - def output_keys(self) -> set[Tensor]: - return set(self.key_order) + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + keys = set(self.key_order) + return keys, keys @staticmethod def _select_ordered_subdict( @@ -120,36 +111,26 @@ class _Matrixify(Transform[Jacobians, JacobianMatrices]): def __init__(self, required_keys: Iterable[Tensor]): self._required_keys = set(required_keys) - def _compute(self, jacobians: Jacobians) -> JacobianMatrices: + def __call__(self, jacobians: Jacobians) -> JacobianMatrices: jacobian_matrices = { key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items() } return JacobianMatrices(jacobian_matrices) - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys - - @property - def output_keys(self) -> set[Tensor]: - return self._required_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return self._required_keys, self._required_keys class _Reshape(Transform[GradientVectors, Gradients]): def __init__(self, required_keys: Iterable[Tensor]): self._required_keys = set(required_keys) - def _compute(self, gradient_vectors: GradientVectors) -> Gradients: + def __call__(self, gradient_vectors: GradientVectors) -> Gradients: gradients = { key: gradient_vector.view(key.shape) for key, gradient_vector in gradient_vectors.items() } return Gradients(gradients) - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys - - @property - def output_keys(self) -> set[Tensor]: - return self._required_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return self._required_keys, self._required_keys diff --git a/src/torchjd/autojac/_transform/base.py b/src/torchjd/autojac/_transform/base.py index b1588f92a..4b2922d4f 100644 --- a/src/torchjd/autojac/_transform/base.py +++ b/src/torchjd/autojac/_transform/base.py @@ -40,24 +40,22 @@ def __str__(self) -> str: return type(self).__name__ @abstractmethod - def _compute(self, input: _B) -> _C: - """Applies the transform to the input.""" - def __call__(self, input: _B) -> _C: - input.check_keys_are(self.required_keys) - return self._compute(input) + """Applies the transform to the input.""" - @property @abstractmethod - def required_keys(self) -> set[Tensor]: - """ - Returns the set of keys that the transform requires to be present in its input TensorDicts. + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: """ + Returns a pair containing (in order) the required keys and the output keys of the Transform + and recursively checks that the transform is valid. - @property - @abstractmethod - def output_keys(self) -> set[Tensor]: - """Returns the set of keys that will be present in the output of the transform.""" + The required keys are the set of keys that the transform requires to be present in its input + TensorDicts. The output keys are the set of keys that will be present in the output + TensorDicts of the transform. + + Since the computation of the required and output keys and the verification that the + transform is valid are sometimes intertwined operations, we do them in a single method. + """ __lshift__ = compose __or__ = conjunct @@ -65,47 +63,31 @@ def output_keys(self) -> set[Tensor]: class Composition(Transform[_A, _C]): def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]): - if outer.required_keys != inner.output_keys: - raise ValueError( - "The `output_keys` of `inner` must match with the `required_keys` of " - f"outer. Found {outer.required_keys} and {inner.output_keys}" - ) self.outer = outer self.inner = inner def __str__(self) -> str: return str(self.outer) + " ∘ " + str(self.inner) - def _compute(self, input: _A) -> _C: + def __call__(self, input: _A) -> _C: intermediate = self.inner(input) return self.outer(intermediate) - @property - def required_keys(self) -> set[Tensor]: - return self.inner.required_keys - - @property - def output_keys(self) -> set[Tensor]: - return self.outer.output_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + outer_required_keys, outer_output_keys = self.outer.check_and_get_keys() + inner_required_keys, inner_output_keys = self.inner.check_and_get_keys() + if outer_required_keys != inner_output_keys: + raise ValueError( + "The `output_keys` of `inner` must match with the `required_keys` of " + f"outer. Found {outer_required_keys} and {inner_output_keys}" + ) + return inner_required_keys, outer_output_keys class Conjunction(Transform[_A, _B]): def __init__(self, transforms: Sequence[Transform[_A, _B]]): self.transforms = transforms - self._required_keys = set( - key for transform in transforms for key in transform.required_keys - ) - for transform in transforms: - if transform.required_keys != self.required_keys: - raise ValueError("All transforms should require the same set of keys.") - - output_keys_with_duplicates = [key for t in transforms for key in t.output_keys] - self._output_keys = set(output_keys_with_duplicates) - - if len(self._output_keys) != len(output_keys_with_duplicates): - raise ValueError("The sets of output keys of transforms should be disjoint.") - def __str__(self) -> str: strings = [] for t in self.transforms: @@ -116,14 +98,22 @@ def __str__(self) -> str: strings.append(s) return "(" + " | ".join(strings) + ")" - def _compute(self, tensor_dict: _A) -> _B: + def __call__(self, tensor_dict: _A) -> _B: output = _union([transform(tensor_dict) for transform in self.transforms]) return output - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + keys_pairs = [transform.check_and_get_keys() for transform in self.transforms] + + required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys) + for transform_required_keys, _ in keys_pairs: + if transform_required_keys != required_keys: + raise ValueError("All transforms should require the same set of keys.") + + output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys] + output_keys = set(output_keys_with_duplicates) + + if len(output_keys) != len(output_keys_with_duplicates): + raise ValueError("The sets of output keys of transforms should be disjoint.") - @property - def output_keys(self) -> set[Tensor]: - return self._output_keys + return required_keys, output_keys diff --git a/src/torchjd/autojac/_transform/diagonalize.py b/src/torchjd/autojac/_transform/diagonalize.py index 7faf5a206..c7ff9fa4e 100644 --- a/src/torchjd/autojac/_transform/diagonalize.py +++ b/src/torchjd/autojac/_transform/diagonalize.py @@ -18,7 +18,7 @@ def __init__(self, considered: Iterable[Tensor]): self.indices.append((begin, end)) begin = end - def _compute(self, tensors: Gradients) -> Jacobians: + def __call__(self, tensors: Gradients) -> Jacobians: flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered] diagonal_matrix = torch.cat(flattened_considered_values).diag() diagonalized_tensors = { @@ -27,10 +27,6 @@ def _compute(self, tensors: Gradients) -> Jacobians: } return Jacobians(diagonalized_tensors) - @property - def required_keys(self) -> set[Tensor]: - return set(self.considered) - - @property - def output_keys(self) -> set[Tensor]: - return set(self.considered) + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + keys = set(self.considered) + return keys, keys diff --git a/src/torchjd/autojac/_transform/init.py b/src/torchjd/autojac/_transform/init.py index 7afafff39..c42332944 100644 --- a/src/torchjd/autojac/_transform/init.py +++ b/src/torchjd/autojac/_transform/init.py @@ -11,7 +11,7 @@ class Init(Transform[EmptyTensorDict, Gradients]): def __init__(self, values: Iterable[Tensor]): self.values = set(values) - def _compute(self, input: EmptyTensorDict) -> Gradients: + def __call__(self, input: EmptyTensorDict) -> Gradients: r""" Computes the gradients of the ``value`` with respect to itself. Returns the result as a dictionary. The only key of the dictionary is ``value``. The corresponding gradient is a @@ -21,10 +21,5 @@ def _compute(self, input: EmptyTensorDict) -> Gradients: return Gradients({value: torch.ones_like(value) for value in self.values}) - @property - def required_keys(self) -> set[Tensor]: - return set() - - @property - def output_keys(self) -> set[Tensor]: - return self.values + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return set(), self.values diff --git a/src/torchjd/autojac/_transform/select.py b/src/torchjd/autojac/_transform/select.py index fa72a83ef..50a691b75 100644 --- a/src/torchjd/autojac/_transform/select.py +++ b/src/torchjd/autojac/_transform/select.py @@ -11,17 +11,13 @@ def __init__(self, keys: Iterable[Tensor], required_keys: Iterable[Tensor]): self.keys = set(keys) self._required_keys = set(required_keys) - if not self.keys.issubset(self._required_keys): - raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`") - - def _compute(self, tensor_dict: _A) -> _A: + def __call__(self, tensor_dict: _A) -> _A: output = {key: tensor_dict[key] for key in self.keys} return type(tensor_dict)(output) - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + required_keys = self._required_keys + if not self.keys.issubset(required_keys): + raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`") - @property - def output_keys(self) -> set[Tensor]: - return self.keys + return required_keys, self.keys diff --git a/src/torchjd/autojac/_transform/stack.py b/src/torchjd/autojac/_transform/stack.py index 527186b7b..6e1daebe2 100644 --- a/src/torchjd/autojac/_transform/stack.py +++ b/src/torchjd/autojac/_transform/stack.py @@ -12,25 +12,22 @@ class Stack(Transform[_A, Jacobians]): def __init__(self, transforms: Sequence[Transform[_A, Gradients]]): self.transforms = transforms - self._required_keys = {key for transform in transforms for key in transform.required_keys} - self._output_keys = {key for transform in transforms for key in transform.output_keys} - - for transform in transforms: - if transform.required_keys != self.required_keys: - raise ValueError("All transforms should require the same set of keys.") - - def _compute(self, input: _A) -> Jacobians: + def __call__(self, input: _A) -> Jacobians: results = [transform(input) for transform in self.transforms] result = _stack(results) return result - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + keys_pairs = [transform.check_and_get_keys() for transform in self.transforms] + + required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys) + output_keys = set(key for _, output_keys in keys_pairs for key in output_keys) + + for transform_required_keys, _ in keys_pairs: + if transform_required_keys != required_keys: + raise ValueError("All transforms should require the same set of keys.") - @property - def output_keys(self) -> set[Tensor]: - return self._output_keys + return required_keys, output_keys def _stack(gradient_dicts: list[Gradients]) -> Jacobians: diff --git a/src/torchjd/autojac/_transform/tensor_dict.py b/src/torchjd/autojac/_transform/tensor_dict.py index 3fcf034c6..9e89d6476 100644 --- a/src/torchjd/autojac/_transform/tensor_dict.py +++ b/src/torchjd/autojac/_transform/tensor_dict.py @@ -14,19 +14,6 @@ def __init__(self, tensor_dict: dict[Tensor, Tensor]): self._check_all_pairs(tensor_dict) super().__init__(tensor_dict) - def check_keys_are(self, keys: set[Tensor]) -> None: - """ - Checks that the keys in the mapping are the same as the provided ``keys``. - - :param keys: Keys that the mapping should (exclusively) contain. - """ - - if set(keys) != set(self.keys()): - raise ValueError( - f"The keys of the {self.__class__.__name__} should be {keys}. Found self.keys = " - f"{self.keys()}." - ) - @staticmethod def _check_dict(tensor_dict: dict[Tensor, Tensor]) -> None: pass diff --git a/src/torchjd/autojac/backward.py b/src/torchjd/autojac/backward.py index 01b562390..daf1e1b02 100644 --- a/src/torchjd/autojac/backward.py +++ b/src/torchjd/autojac/backward.py @@ -4,7 +4,7 @@ from torchjd.aggregation import Aggregator -from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac +from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors @@ -78,6 +78,26 @@ def backward( else: inputs = set(inputs) + backward_transform = _create_transform( + tensors=tensors, + aggregator=aggregator, + inputs=inputs, + retain_graph=retain_graph, + parallel_chunk_size=parallel_chunk_size, + ) + + backward_transform(EmptyTensorDict()) + + +def _create_transform( + tensors: list[Tensor], + aggregator: Aggregator, + inputs: set[Tensor], + retain_graph: bool, + parallel_chunk_size: int | None, +) -> Transform[EmptyTensorDict, EmptyTensorDict]: + """Creates the Jacobian descent backward transform.""" + # Transform that creates gradient outputs containing only ones. init = Init(tensors) @@ -93,6 +113,4 @@ def backward( # Transform that accumulates the result in the .grad field of the inputs. accumulate = Accumulate(inputs) - backward_transform = accumulate << aggregate << jac << diag << init - - backward_transform(EmptyTensorDict()) + return accumulate << aggregate << jac << diag << init diff --git a/src/torchjd/autojac/mtl_backward.py b/src/torchjd/autojac/mtl_backward.py index 78e357092..c1ed0c48a 100644 --- a/src/torchjd/autojac/mtl_backward.py +++ b/src/torchjd/autojac/mtl_backward.py @@ -98,6 +98,34 @@ def mtl_backward( if len(losses) != len(tasks_params): raise ValueError("`losses` and `tasks_params` should have the same size.") + backward_transform = _create_transform( + losses=losses, + features=features, + aggregator=aggregator, + tasks_params=tasks_params, + shared_params=shared_params, + retain_graph=retain_graph, + parallel_chunk_size=parallel_chunk_size, + ) + + backward_transform(EmptyTensorDict()) + + +def _create_transform( + losses: Sequence[Tensor], + features: list[Tensor], + aggregator: Aggregator, + tasks_params: list[Iterable[Tensor]], + shared_params: set[Tensor], + retain_graph: bool, + parallel_chunk_size: int | None, +) -> Transform[EmptyTensorDict, EmptyTensorDict]: + """ + Creates the backward transform for a multi-task learning problem. It is a hybrid between + Jacobian descent (for shared parameters) and multiple gradient descent branches (for + task-specific parameters). + """ + shared_params = list(shared_params) tasks_params = [list(task_params) for task_params in tasks_params] @@ -105,7 +133,7 @@ def mtl_backward( # loss w.r.t. the task's specific parameters, and computes and backpropagates the gradient of # the losses w.r.t. the shared representations. task_transforms = [ - _make_task_transform( + _create_task_transform( features, task_params, loss, @@ -127,12 +155,10 @@ def mtl_backward( # Transform that accumulates the result in the .grad field of the shared parameters. accumulate = Accumulate(shared_params) - backward_transform = accumulate << aggregate << jac << stack - - backward_transform(EmptyTensorDict()) + return accumulate << aggregate << jac << stack -def _make_task_transform( +def _create_task_transform( features: list[Tensor], tasks_params: list[Tensor], loss: Tensor, diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index a1b89fba9..8967d3646 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -93,3 +93,15 @@ def test_no_leaf_and_no_retains_grad_fails(): with raises(ValueError): accumulate(input) + + +def test_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + key = torch.tensor([1.0], requires_grad=True) + accumulate = Accumulate([key]) + + required_keys, output_keys = accumulate.check_and_get_keys() + + assert required_keys == {key} + assert output_keys == set() diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py index 31a6f952a..992bd330e 100644 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ b/tests/unit/autojac/_transform/test_aggregate.py @@ -142,3 +142,42 @@ def test_reshape(): } assert_tensor_dicts_are_close(output, expected_output) + + +def test_aggregate_matrices_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + key1 = torch.tensor([1.0]) + key2 = torch.tensor([2.0]) + aggregate = _AggregateMatrices(Random(), [key2, key1]) + + required_keys, output_keys = aggregate.check_and_get_keys() + + assert required_keys == {key1, key2} + assert output_keys == {key1, key2} + + +def test_matrixify_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + key1 = torch.tensor([1.0]) + key2 = torch.tensor([2.0]) + matrixify = _Matrixify([key1, key2]) + + required_keys, output_keys = matrixify.check_and_get_keys() + + assert required_keys == {key1, key2} + assert output_keys == {key1, key2} + + +def test_reshape_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + key1 = torch.tensor([1.0]) + key2 = torch.tensor([2.0]) + reshape = _Reshape([key1, key2]) + + required_keys, output_keys = reshape.check_and_get_keys() + + assert required_keys == {key1, key2} + assert output_keys == {key1, key2} diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index c93b30721..5d2864e8b 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -21,47 +21,20 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): def __str__(self): return "T" - def _compute(self, input: _B) -> _C: + def __call__(self, input: _B) -> _C: # Ignore the input, create a dictionary with the right keys as an output. # Cast the type for the purpose of type-checking. output_dict = {key: torch.empty(0) for key in self._output_keys} return typing.cast(_C, output_dict) - @property - def required_keys(self) -> set[Tensor]: - return self._required_keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return self._required_keys, self._output_keys - @property - def output_keys(self) -> set[Tensor]: - return self._output_keys - -def test_call_checks_keys(): - """ - Tests that a ``Transform`` checks that the provided dictionary to the `__call__` function - contains keys that correspond exactly to `required_keys`. - """ - - a1 = torch.randn([2]) - a2 = torch.randn([3]) - t = FakeTransform(required_keys={a1}, output_keys={a1, a2}) - - t(TensorDict({a1: a2})) - - with raises(ValueError): - t(TensorDict({a2: a1})) - - with raises(ValueError): - t(TensorDict({})) - - with raises(ValueError): - t(TensorDict({a1: a2, a2: a1})) - - -def test_compose_checks_keys(): +def test_composition_check_and_get_keys(): """ - Tests that the composition of ``Transform``s checks that the inner transform's `output_keys` - match with the outer transform's `required_keys`. + Tests that `check_and_get_keys` works correctly for a composition of transforms: the inner + transform's `output_keys` has to match with the outer transform's `required_keys`. """ a1 = torch.randn([2]) @@ -69,16 +42,19 @@ def test_compose_checks_keys(): t1 = FakeTransform(required_keys={a1}, output_keys={a1, a2}) t2 = FakeTransform(required_keys={a2}, output_keys={a1}) - t1 << t2 + required_keys, output_keys = (t1 << t2).check_and_get_keys() + + assert required_keys == {a2} + assert output_keys == {a1, a2} with raises(ValueError): - t2 << t1 + (t2 << t1).check_and_get_keys() -def test_conjunct_checks_required_keys(): +def test_conjunct_check_and_get_keys_1(): """ - Tests that the conjunction of ``Transform``s checks that the provided transforms all have the - same `required_keys`. + Tests that `check_and_get_keys` works correctly for a conjunction of transforms: all of them + should have the same `required_keys`. """ a1 = torch.randn([2]) @@ -88,19 +64,22 @@ def test_conjunct_checks_required_keys(): t2 = FakeTransform(required_keys={a1}, output_keys=set()) t3 = FakeTransform(required_keys={a2}, output_keys=set()) - t1 | t2 + required_keys, output_keys = (t1 | t2).check_and_get_keys() + + assert required_keys == {a1} + assert output_keys == set() with raises(ValueError): - t2 | t3 + (t2 | t3).check_and_get_keys() with raises(ValueError): - t1 | t2 | t3 + (t1 | t2 | t3).check_and_get_keys() -def test_conjunct_checks_output_keys(): +def test_conjunct_check_and_get_keys_2(): """ - Tests that the conjunction of ``Transform``s checks that the transforms `output_keys` are - disjoint. + Tests that `check_and_get_keys` works correctly for a conjunction of transforms: their + `output_keys` should be disjoint. """ a1 = torch.randn([2]) @@ -110,13 +89,16 @@ def test_conjunct_checks_output_keys(): t2 = FakeTransform(required_keys=set(), output_keys={a1}) t3 = FakeTransform(required_keys=set(), output_keys={a2}) - t2 | t3 + required_keys, output_keys = (t2 | t3).check_and_get_keys() + + assert required_keys == set() + assert output_keys == {a1, a2} with raises(ValueError): - t1 | t3 + (t1 | t3).check_and_get_keys() with raises(ValueError): - t1 | t2 | t3 + (t1 | t2 | t3).check_and_get_keys() def test_empty_conjunction(): diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index 320b68098..5e4cc74ef 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -95,3 +95,15 @@ def test_permute_order(): expected_output = diag(input) assert_tensor_dicts_are_close(output, expected_output) + + +def test_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + key = torch.tensor([1.0]) + diag = Diagonalize([key]) + + required_keys, output_keys = diag.check_and_get_keys() + + assert required_keys == {key} + assert output_keys == {key} diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index cef17d590..f8d6dcead 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -281,3 +281,19 @@ def test_create_graph(): gradients = grad(input) assert gradients[a].requires_grad + + +def test_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + x = torch.tensor(5.0) + a1 = torch.tensor(2.0, requires_grad=True) + a2 = torch.tensor(3.0, requires_grad=True) + y = torch.stack([a1 * x, a2 * x]) + + grad = Grad(outputs=[y], inputs=[a1, a2]) + + required_keys, output_keys = grad.check_and_get_keys() + + assert required_keys == {y} + assert output_keys == {a1, a2} diff --git a/tests/unit/autojac/_transform/test_init.py b/tests/unit/autojac/_transform/test_init.py index 1d5bc5689..07b410de3 100644 --- a/tests/unit/autojac/_transform/test_init.py +++ b/tests/unit/autojac/_transform/test_init.py @@ -61,3 +61,15 @@ def test_conjunction_of_inits_is_init(): expected_output = init(input) assert_tensor_dicts_are_close(output, expected_output) + + +def test_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + key = torch.tensor([1.0]) + init = Init([key]) + + required_keys, output_keys = init.check_and_get_keys() + + assert required_keys == set() + assert output_keys == {key} diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 27db5b01f..2cb3e846f 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -251,15 +251,24 @@ def test_equivalence_jac_grads(): assert_close(jac_c, torch.stack([grad_1_c, grad_2_c])) -def test_stack_different_required_keys(): - """Tests that the Stack transform fails on transforms with different required keys.""" +def test_stack_check_and_get_keys(): + """ + Tests that the `check_and_get_keys` method works correctly for a stack of transforms: all of + them should have the same `required_keys`. + """ a = torch.tensor(1.0, requires_grad=True) y1 = a * 2.0 y2 = a * 3.0 grad1 = Grad([y1], [a]) - grad2 = Grad([y2], [a]) + grad2 = Grad([y1], [a]) + grad3 = Grad([y2], [a]) + + required_keys, output_keys = Stack([grad1, grad2]).check_and_get_keys() + + assert required_keys == {y1} + assert output_keys == {a} with raises(ValueError): - _ = Stack([grad1, grad2]) + Stack([grad1, grad3]).check_and_get_keys() diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index 54ac47e34..32121f068 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -281,3 +281,19 @@ def test_create_graph(): assert jacobians[a1].requires_grad assert jacobians[a2].requires_grad + + +def test_check_and_get_keys(): + """Tests that the `check_and_get_keys` method works correctly.""" + + x = torch.tensor(5.0) + a1 = torch.tensor(2.0, requires_grad=True) + a2 = torch.tensor(3.0, requires_grad=True) + y = torch.stack([a1 * x, a2 * x]) + + jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=None) + + required_keys, output_keys = jac.check_and_get_keys() + + assert required_keys == {y} + assert output_keys == {a1, a2} diff --git a/tests/unit/autojac/_transform/test_select.py b/tests/unit/autojac/_transform/test_select.py index 7cf37b165..a6fa23f6b 100644 --- a/tests/unit/autojac/_transform/test_select.py +++ b/tests/unit/autojac/_transform/test_select.py @@ -1,8 +1,5 @@ -from contextlib import nullcontext as does_not_raise - import torch -from pytest import mark, raises -from unit._utils import ExceptionContext +from pytest import raises from torchjd.autojac._transform import Select, TensorDict @@ -59,26 +56,20 @@ def test_conjunction_of_selects_is_select(): assert_tensor_dicts_are_close(output, expected_output) -@mark.parametrize( - ["key_indices", "required_key_indices", "expectation"], - [ - ([0], [0, 1], does_not_raise()), - ([0], [1], raises(ValueError)), - ([0, 1], [0], raises(ValueError)), - ([], [0], does_not_raise()), - ], -) -def test_keys_check( - key_indices: list[int], required_key_indices: list[int], expectation: ExceptionContext -): +def test_check_and_get_keys(): """ - Tests that the Select transform correctly checks that the keys are a subset of the required - keys. + Tests that the `check_and_get_keys` method works correctly: the set of keys to select should + be a subset of the set of required_keys. """ - all_keys = [torch.tensor(i) for i in range(2)] - keys = [all_keys[i] for i in key_indices] - required_keys = [all_keys[i] for i in required_key_indices] + key1 = torch.tensor([1.0]) + key2 = torch.tensor([2.0]) + key3 = torch.tensor([3.0]) + + required_keys, output_keys = Select([key1, key2], [key1, key2, key3]).check_and_get_keys() + + assert required_keys == {key1, key2, key3} + assert output_keys == {key1, key2} - with expectation: - _ = Select(keys, required_keys) + with raises(ValueError): + Select([key1, key2], [key1]).check_and_get_keys() diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index fdb45008c..692fabbdb 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -17,16 +17,11 @@ class FakeGradientsTransform(Transform[EmptyTensorDict, Gradients]): def __init__(self, keys: Iterable[Tensor]): self.keys = set(keys) - def _compute(self, input: EmptyTensorDict) -> Gradients: + def __call__(self, input: EmptyTensorDict) -> Gradients: return Gradients({key: torch.ones_like(key) for key in self.keys}) - @property - def required_keys(self) -> set[Tensor]: - return set() - - @property - def output_keys(self) -> set[Tensor]: - return self.keys + def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: + return set(), self.keys def test_single_key(): diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index f8f16c8c2..fab1877e6 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -5,6 +5,29 @@ from torchjd import backward from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad +from torchjd.autojac.backward import _create_transform + + +def test_check_create_transform(): + """Tests that _create_transform creates a valid Transform""" + + a1 = torch.tensor([1.0, 2.0], requires_grad=True) + a2 = torch.tensor([3.0, 4.0], requires_grad=True) + + y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum() + y2 = (a1**2).sum() + a2.norm() + + transform = _create_transform( + tensors=[y1, y2], + aggregator=Mean(), + inputs={a1, a2}, + retain_graph=False, + parallel_chunk_size=None, + ) + required_keys, output_keys = transform.check_and_get_keys() + + assert required_keys == set() + assert output_keys == set() @mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 6f98a4ac4..9fd4e3a93 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -5,6 +5,34 @@ from torchjd import mtl_backward from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad +from torchjd.autojac.mtl_backward import _create_transform + + +def test_check_create_transform(): + """Tests that _create_transform creates a valid Transform""" + + p0 = torch.tensor([1.0, 2.0], requires_grad=True) + p1 = torch.tensor([1.0, 2.0], requires_grad=True) + p2 = torch.tensor([3.0, 4.0], requires_grad=True) + + f1 = torch.tensor([-1.0, 1.0]) @ p0 + f2 = (p0**2).sum() + p0.norm() + y1 = f1 * p1[0] + f2 * p1[1] + y2 = f1 * p2[0] + f2 * p2[1] + + transform = _create_transform( + losses=[y1, y2], + features=[f1, f2], + aggregator=Mean(), + tasks_params=[[p1], [p2]], + shared_params={p0}, + retain_graph=False, + parallel_chunk_size=None, + ) + required_keys, output_keys = transform.check_and_get_keys() + + assert required_keys == set() + assert output_keys == set() @mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])