Commit 1b4ac0e

Make Accumulate not check required_keys

1 parent: c4778a0

3 files changed: 4 additions & 14 deletions

src/torchjd/autojac/_transform/accumulate.py

Lines changed: 1 addition & 11 deletions

@@ -1,15 +1,10 @@
-from typing import Iterable
-
 from torch import Tensor
 
-from .base import RequirementError, Transform
+from .base import Transform
 from .tensor_dict import EmptyTensorDict, Gradients
 
 
 class Accumulate(Transform[Gradients, EmptyTensorDict]):
-    def __init__(self, required_keys: Iterable[Tensor]):
-        self._required_keys = set(required_keys)
-
     def __call__(self, gradients: Gradients) -> EmptyTensorDict:
         """
         Accumulates gradients with respect to keys in their ``.grad`` field.
@@ -29,11 +24,6 @@ def __call__(self, gradients: Gradients) -> EmptyTensorDict:
         return EmptyTensorDict()
 
     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
-        if not self._required_keys.issubset(input_keys):
-            raise RequirementError(
-                f"The input_keys needs to be a super set of the required_keys. Found {input_keys} "
-                f"and {self._required_keys}"
-            )
         return set()
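
With its constructor removed, `Accumulate` is now stateless: `check_keys` unconditionally returns an empty set, so the transform accepts whatever keys its input provides instead of raising `RequirementError` when the `required_keys` are missing. For context, here is a minimal sketch of the accumulation behavior that `__call__` implements, written as a hypothetical stand-in over a plain dict rather than torchjd's `Gradients` type:

```python
import torch

def accumulate_sketch(gradients: dict[torch.Tensor, torch.Tensor]) -> None:
    # Hypothetical stand-in for Accumulate.__call__: add each gradient into
    # the .grad field of its key tensor, initializing the field if unset.
    for key, grad in gradients.items():
        if key.grad is None:
            key.grad = grad.clone()
        else:
            key.grad = key.grad + grad
```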

src/torchjd/autojac/backward.py

Lines changed: 1 addition & 1 deletion

@@ -112,6 +112,6 @@ def _create_transform(
     aggregate = Aggregate(aggregator, inputs)
 
     # Transform that accumulates the result in the .grad field of the inputs.
-    accumulate = Accumulate(inputs)
+    accumulate = Accumulate()
 
     return accumulate << aggregate << jac << diag << init
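
The call site simply drops the now-removed `inputs` argument. For the composition on the last line, `<<` is assumed to chain transforms like right-to-left function composition (a sketch of the idea only, not the actual operator implementation on torchjd's `Transform` base class):

```python
from typing import Callable, TypeVar

A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")

def compose(outer: Callable[[B], C], inner: Callable[[A], B]) -> Callable[[A], C]:
    # What `outer << inner` is taken to mean: run `inner` first, then
    # feed its output to `outer`.
    return lambda x: outer(inner(x))
```

Under that reading, `init` runs first and `accumulate` runs last, so `Accumulate` only ever receives the gradients that `aggregate` produced for `inputs`, which is presumably what makes a separate `required_keys` check redundant here.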

src/torchjd/autojac/mtl_backward.py

Lines changed: 2 additions & 2 deletions

@@ -155,7 +155,7 @@ def _create_transform(
     aggregate = Aggregate(aggregator, shared_params)
 
     # Transform that accumulates the result in the .grad field of the shared parameters.
-    accumulate = Accumulate(shared_params)
+    accumulate = Accumulate()
 
     return accumulate << aggregate << jac << stack
 
@@ -179,7 +179,7 @@ def _create_task_transform(
 
     # Transform that accumulates the gradients w.r.t. the task-specific parameters into their
     # .grad fields.
-    accumulate = Accumulate(task_params) << Select(task_params)
+    accumulate = Accumulate() << Select(task_params)
 
     # Transform that backpropagates the gradients of the losses w.r.t. the features.
     backpropagate = Select(features)
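
In the task-specific transform, `Accumulate` sits behind `Select(task_params)`, which narrows the incoming gradients to the task parameters before accumulation. A sketch of that filtering step, again as a hypothetical stand-in over plain dicts rather than the real `Select` transform:

```python
import torch

def select_sketch(
    tensor_dict: dict[torch.Tensor, torch.Tensor],
    keys: list[torch.Tensor],
) -> dict[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for Select(keys): keep only the entries whose
    # key is one of the requested tensors, compared by identity.
    requested = {id(k) for k in keys}
    return {k: v for k, v in tensor_dict.items() if id(k) in requested}
```

Since `Select` already determines exactly which keys reach `Accumulate`, the removed validation was not adding safety in this pipeline either.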
