Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3f11448
Decouple construction and checking in Transforms
PierreQuinton Mar 28, 2025
5a97c18
add check_keys in test_stack.FakeGradientsTransform
PierreQuinton Mar 28, 2025
c804e5e
Fix tests checking Transforms
PierreQuinton Mar 28, 2025
ea5dfac
remove test_call_checks_keys
PierreQuinton Mar 28, 2025
8adad9d
Add changelog entry
ValerianRey Mar 28, 2025
ee3e099
Extract _create_transform from backward
PierreQuinton Mar 28, 2025
ac2b31b
Test _create_transform
PierreQuinton Mar 28, 2025
4979038
Fix mtl_backward test of create transform and check keys
PierreQuinton Mar 28, 2025
df6c200
Remove _compute, use __call__ instead
ValerianRey Mar 28, 2025
8d3d98d
Same as before for FakeTransforms
ValerianRey Mar 28, 2025
69b2f75
Rename check_keys to check_and_get_keys
ValerianRey Mar 28, 2025
f278a5c
Improve docstring of check_and_get_keys
ValerianRey Mar 28, 2025
316d3dd
Fix tests of check_and_get_keys
ValerianRey Mar 28, 2025
fcd1e3b
Move check and cast to set from __init__ to check_and_get_keys in Select
ValerianRey Mar 28, 2025
b518046
Change test_keys_check into test_check_and_get_keys in Select
ValerianRey Mar 28, 2025
f490089
Add missing tests of check_and_get_keys
ValerianRey Mar 28, 2025
b64665b
revert cast to set in Select
PierreQuinton Mar 28, 2025
ab912a1
Remove check_keys_are
ValerianRey Mar 28, 2025
b744af0
Remove useless casting to set in Select.check_and_get_keys
ValerianRey Mar 28, 2025
4f6e515
Add docstring to mtl_backward._create_transform
ValerianRey Mar 28, 2025
e2b3374
Rename _make_task_transform to _create_task_transform
ValerianRey Mar 28, 2025
ffb6d55
Uniformize test_check_create_transform for backward and mtl_backward …
ValerianRey Mar 28, 2025
e72d3ca
Explicitly give parameter names in calls to _create_transform
ValerianRey Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ changes that do not affect the user.

- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
project onto the dual cone. This may minimally affect the output of these aggregators.
- Refactored internal verifications in the autojac engine so that they no longer run at runtime.
  This should minimally improve the performance and reduce the memory usage of `backward`
  and `mtl_backward`.

### Fixed
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they
Expand Down
12 changes: 3 additions & 9 deletions src/torchjd/autojac/_transform/_differentiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
self.retain_graph = retain_graph
self.create_graph = create_graph

def _compute(self, tensors: _A) -> _A:
def __call__(self, tensors: _A) -> _A:
tensor_outputs = [tensors[output] for output in self.outputs]

differentiated_tuple = self._differentiate(tensor_outputs)
Expand All @@ -38,12 +38,6 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]
tensor_outputs should be.
"""

@property
def required_keys(self) -> set[Tensor]:
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
# outputs in the forward direction become inputs in the backward direction, and vice-versa
return set(self.outputs)

@property
def output_keys(self) -> set[Tensor]:
# outputs in the forward direction become inputs in the backward direction, and vice-versa
return set(self.inputs)
return set(self.outputs), set(self.inputs)
11 changes: 3 additions & 8 deletions src/torchjd/autojac/_transform/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Accumulate(Transform[Gradients, EmptyTensorDict]):
def __init__(self, required_keys: Iterable[Tensor]):
self._required_keys = set(required_keys)

def _compute(self, gradients: Gradients) -> EmptyTensorDict:
def __call__(self, gradients: Gradients) -> EmptyTensorDict:
"""
Accumulates gradients with respect to keys in their ``.grad`` field.
"""
Expand All @@ -28,13 +28,8 @@ def _compute(self, gradients: Gradients) -> EmptyTensorDict:

return EmptyTensorDict()

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys

@property
def output_keys(self) -> set[Tensor]:
return set()
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
return self._required_keys, set()


def _check_expects_grad(tensor: Tensor) -> None:
Expand Down
45 changes: 13 additions & 32 deletions src/torchjd/autojac/_transform/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,19 @@ def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
self._aggregator_str = str(aggregator)
self.transform = reshape << aggregate_matrices << matrixify

def _compute(self, input: Jacobians) -> Gradients:
def __call__(self, input: Jacobians) -> Gradients:
return self.transform(input)

@property
def required_keys(self) -> set[Tensor]:
return self.transform.required_keys

@property
def output_keys(self) -> set[Tensor]:
return self.transform.output_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
return self.transform.check_and_get_keys()


class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
self.key_order = ordered_set(key_order)
self.aggregator = aggregator

def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
"""
Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using
the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to
Expand All @@ -53,13 +48,9 @@ def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order)
return self._aggregate_group(ordered_matrices, self.aggregator)

@property
def required_keys(self) -> set[Tensor]:
return set(self.key_order)

@property
def output_keys(self) -> set[Tensor]:
return set(self.key_order)
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
keys = set(self.key_order)
return keys, keys

@staticmethod
def _select_ordered_subdict(
Expand Down Expand Up @@ -120,36 +111,26 @@ class _Matrixify(Transform[Jacobians, JacobianMatrices]):
def __init__(self, required_keys: Iterable[Tensor]):
self._required_keys = set(required_keys)

def _compute(self, jacobians: Jacobians) -> JacobianMatrices:
def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
jacobian_matrices = {
key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
}
return JacobianMatrices(jacobian_matrices)

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys

@property
def output_keys(self) -> set[Tensor]:
return self._required_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
return self._required_keys, self._required_keys


class _Reshape(Transform[GradientVectors, Gradients]):
def __init__(self, required_keys: Iterable[Tensor]):
self._required_keys = set(required_keys)

def _compute(self, gradient_vectors: GradientVectors) -> Gradients:
def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
gradients = {
key: gradient_vector.view(key.shape)
for key, gradient_vector in gradient_vectors.items()
}
return Gradients(gradients)

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys

@property
def output_keys(self) -> set[Tensor]:
return self._required_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
return self._required_keys, self._required_keys
82 changes: 36 additions & 46 deletions src/torchjd/autojac/_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,72 +40,54 @@ def __str__(self) -> str:
return type(self).__name__

@abstractmethod
def _compute(self, input: _B) -> _C:
"""Applies the transform to the input."""

def __call__(self, input: _B) -> _C:
input.check_keys_are(self.required_keys)
return self._compute(input)
"""Applies the transform to the input."""

@property
@abstractmethod
def required_keys(self) -> set[Tensor]:
"""
Returns the set of keys that the transform requires to be present in its input TensorDicts.
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
"""
Returns a pair containing (in order) the required keys and the output keys of the Transform
and recursively checks that the transform is valid.

@property
@abstractmethod
def output_keys(self) -> set[Tensor]:
"""Returns the set of keys that will be present in the output of the transform."""
The required keys are the set of keys that the transform requires to be present in its input
TensorDicts. The output keys are the set of keys that will be present in the output
TensorDicts of the transform.

Since the computation of the required and output keys and the verification that the
transform is valid are sometimes intertwined operations, we do them in a single method.
"""

__lshift__ = compose
__or__ = conjunct


class Composition(Transform[_A, _C]):
def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
if outer.required_keys != inner.output_keys:
raise ValueError(
"The `output_keys` of `inner` must match with the `required_keys` of "
f"outer. Found {outer.required_keys} and {inner.output_keys}"
)
self.outer = outer
self.inner = inner

def __str__(self) -> str:
return str(self.outer) + " ∘ " + str(self.inner)

def _compute(self, input: _A) -> _C:
def __call__(self, input: _A) -> _C:
intermediate = self.inner(input)
return self.outer(intermediate)

@property
def required_keys(self) -> set[Tensor]:
return self.inner.required_keys

@property
def output_keys(self) -> set[Tensor]:
return self.outer.output_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
outer_required_keys, outer_output_keys = self.outer.check_and_get_keys()
inner_required_keys, inner_output_keys = self.inner.check_and_get_keys()
if outer_required_keys != inner_output_keys:
raise ValueError(
"The `output_keys` of `inner` must match with the `required_keys` of "
f"outer. Found {outer_required_keys} and {inner_output_keys}"
)
return inner_required_keys, outer_output_keys


class Conjunction(Transform[_A, _B]):
def __init__(self, transforms: Sequence[Transform[_A, _B]]):
self.transforms = transforms

self._required_keys = set(
key for transform in transforms for key in transform.required_keys
)
for transform in transforms:
if transform.required_keys != self.required_keys:
raise ValueError("All transforms should require the same set of keys.")

output_keys_with_duplicates = [key for t in transforms for key in t.output_keys]
self._output_keys = set(output_keys_with_duplicates)

if len(self._output_keys) != len(output_keys_with_duplicates):
raise ValueError("The sets of output keys of transforms should be disjoint.")

def __str__(self) -> str:
strings = []
for t in self.transforms:
Expand All @@ -116,14 +98,22 @@ def __str__(self) -> str:
strings.append(s)
return "(" + " | ".join(strings) + ")"

def _compute(self, tensor_dict: _A) -> _B:
def __call__(self, tensor_dict: _A) -> _B:
output = _union([transform(tensor_dict) for transform in self.transforms])
return output

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]

required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
for transform_required_keys, _ in keys_pairs:
if transform_required_keys != required_keys:
raise ValueError("All transforms should require the same set of keys.")

output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys]
output_keys = set(output_keys_with_duplicates)

if len(output_keys) != len(output_keys_with_duplicates):
raise ValueError("The sets of output keys of transforms should be disjoint.")

@property
def output_keys(self) -> set[Tensor]:
return self._output_keys
return required_keys, output_keys
12 changes: 4 additions & 8 deletions src/torchjd/autojac/_transform/diagonalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, considered: Iterable[Tensor]):
self.indices.append((begin, end))
begin = end

def _compute(self, tensors: Gradients) -> Jacobians:
def __call__(self, tensors: Gradients) -> Jacobians:
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered]
diagonal_matrix = torch.cat(flattened_considered_values).diag()
diagonalized_tensors = {
Expand All @@ -27,10 +27,6 @@ def _compute(self, tensors: Gradients) -> Jacobians:
}
return Jacobians(diagonalized_tensors)

@property
def required_keys(self) -> set[Tensor]:
return set(self.considered)

@property
def output_keys(self) -> set[Tensor]:
return set(self.considered)
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
keys = set(self.considered)
return keys, keys
11 changes: 3 additions & 8 deletions src/torchjd/autojac/_transform/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Init(Transform[EmptyTensorDict, Gradients]):
def __init__(self, values: Iterable[Tensor]):
self.values = set(values)

def _compute(self, input: EmptyTensorDict) -> Gradients:
def __call__(self, input: EmptyTensorDict) -> Gradients:
r"""
Computes the gradients of the ``value`` with respect to itself. Returns the result as a
dictionary. The only key of the dictionary is ``value``. The corresponding gradient is a
Expand All @@ -21,10 +21,5 @@ def _compute(self, input: EmptyTensorDict) -> Gradients:

return Gradients({value: torch.ones_like(value) for value in self.values})

@property
def required_keys(self) -> set[Tensor]:
return set()

@property
def output_keys(self) -> set[Tensor]:
return self.values
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
return set(), self.values
16 changes: 6 additions & 10 deletions src/torchjd/autojac/_transform/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@ def __init__(self, keys: Iterable[Tensor], required_keys: Iterable[Tensor]):
self.keys = set(keys)
self._required_keys = set(required_keys)

if not self.keys.issubset(self._required_keys):
raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`")

def _compute(self, tensor_dict: _A) -> _A:
def __call__(self, tensor_dict: _A) -> _A:
output = {key: tensor_dict[key] for key in self.keys}
return type(tensor_dict)(output)

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
required_keys = self._required_keys
if not self.keys.issubset(required_keys):
raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`")

@property
def output_keys(self) -> set[Tensor]:
return self.keys
return required_keys, self.keys
25 changes: 11 additions & 14 deletions src/torchjd/autojac/_transform/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,22 @@ class Stack(Transform[_A, Jacobians]):
def __init__(self, transforms: Sequence[Transform[_A, Gradients]]):
self.transforms = transforms

self._required_keys = {key for transform in transforms for key in transform.required_keys}
self._output_keys = {key for transform in transforms for key in transform.output_keys}

for transform in transforms:
if transform.required_keys != self.required_keys:
raise ValueError("All transforms should require the same set of keys.")

def _compute(self, input: _A) -> Jacobians:
def __call__(self, input: _A) -> Jacobians:
results = [transform(input) for transform in self.transforms]
result = _stack(results)
return result

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]

required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
output_keys = set(key for _, output_keys in keys_pairs for key in output_keys)

for transform_required_keys, _ in keys_pairs:
if transform_required_keys != required_keys:
raise ValueError("All transforms should require the same set of keys.")

@property
def output_keys(self) -> set[Tensor]:
return self._output_keys
return required_keys, output_keys


def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
Expand Down
13 changes: 0 additions & 13 deletions src/torchjd/autojac/_transform/tensor_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,6 @@ def __init__(self, tensor_dict: dict[Tensor, Tensor]):
self._check_all_pairs(tensor_dict)
super().__init__(tensor_dict)

def check_keys_are(self, keys: set[Tensor]) -> None:
"""
Checks that the keys in the mapping are the same as the provided ``keys``.

:param keys: Keys that the mapping should (exclusively) contain.
"""

if set(keys) != set(self.keys()):
raise ValueError(
f"The keys of the {self.__class__.__name__} should be {keys}. Found self.keys = "
f"{self.keys()}."
)

@staticmethod
def _check_dict(tensor_dict: dict[Tensor, Tensor]) -> None:
pass
Expand Down
Loading