Skip to content

Commit 6d7d140

Browse files
authored
feat(autogram): Add kwargs support (#443)
* Add kwargs support * Add WithModuleWithStringKwarg test * Add WithModuleWithHybridPyTreeKwarg test * Factorize existing architectures * Fix some typos
1 parent 6aae58c commit 6d7d140

File tree

4 files changed

+120
-54
lines changed

4 files changed

+120
-54
lines changed

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def hook_module(self, module: nn.Module) -> None:
6565
self._gramian_accumulator,
6666
self._has_batch_dim,
6767
)
68-
self._handles.append(module.register_forward_hook(hook))
68+
self._handles.append(module.register_forward_hook(hook, with_kwargs=True))
6969

7070
@staticmethod
7171
def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
@@ -101,7 +101,13 @@ def __init__(
101101
self.gramian_accumulator = gramian_accumulator
102102
self.has_batch_dim = has_batch_dim
103103

104-
def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree:
104+
def __call__(
105+
self,
106+
module: nn.Module,
107+
args: tuple[PyTree, ...],
108+
kwargs: dict[str, PyTree],
109+
outputs: PyTree,
110+
) -> PyTree:
105111
if self.gramian_accumulation_phase:
106112
return outputs
107113

@@ -131,9 +137,10 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree)
131137

132138
vjp: VJP
133139
if self.has_batch_dim:
134-
rg_outputs_in_dims = (0,) * len(rg_outputs)
135-
args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
136-
in_dims = (rg_outputs_in_dims, args_in_dims)
140+
rg_output_in_dims = (0,) * len(rg_outputs)
141+
arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
142+
kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
143+
in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims)
137144
vjp = FunctionalVJP(module, in_dims)
138145
else:
139146
vjp = AutogradVJP(module, rg_outputs)
@@ -142,6 +149,7 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree)
142149
self.gramian_accumulation_phase,
143150
vjp,
144151
args,
152+
kwargs,
145153
self.gramian_accumulator,
146154
module,
147155
*rg_outputs,
@@ -169,14 +177,15 @@ def forward(
169177
gramian_accumulation_phase: BoolRef,
170178
vjp: VJP,
171179
args: tuple[PyTree, ...],
180+
kwargs: dict[str, PyTree],
172181
gramian_accumulator: GramianAccumulator,
173182
module: nn.Module,
174183
*rg_tensors: Tensor,
175184
) -> tuple[Tensor, ...]:
176185
return tuple(t.detach() for t in rg_tensors)
177186

178187
# For Python version > 3.10, the type of `inputs` should become
179-
# tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
188+
# tuple[BoolRef, VJP, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
180189
@staticmethod
181190
def setup_context(
182191
ctx,
@@ -186,25 +195,27 @@ def setup_context(
186195
ctx.gramian_accumulation_phase = inputs[0]
187196
ctx.vjp = inputs[1]
188197
ctx.args = inputs[2]
189-
ctx.gramian_accumulator = inputs[3]
190-
ctx.module = inputs[4]
198+
ctx.kwargs = inputs[3]
199+
ctx.gramian_accumulator = inputs[4]
200+
ctx.module = inputs[5]
191201

192202
@staticmethod
193203
def backward(ctx, *grad_outputs: Tensor) -> tuple:
194-
# Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]]
204+
# For python > 3.10: -> tuple[None, None, None, None, None, None, *tuple[Tensor, ...]]
195205

196206
if not ctx.gramian_accumulation_phase:
197-
return None, None, None, None, None, *grad_outputs
207+
return None, None, None, None, None, None, *grad_outputs
198208

199209
AccumulateJacobian.apply(
200210
ctx.vjp,
201211
ctx.args,
212+
ctx.kwargs,
202213
ctx.gramian_accumulator,
203214
ctx.module,
204215
*grad_outputs,
205216
)
206217

207-
return None, None, None, None, None, *grad_outputs
218+
return None, None, None, None, None, None, *grad_outputs
208219

209220

210221
class AccumulateJacobian(torch.autograd.Function):
@@ -213,29 +224,31 @@ class AccumulateJacobian(torch.autograd.Function):
213224
def forward(
214225
vjp: VJP,
215226
args: tuple[PyTree, ...],
227+
kwargs: dict[str, PyTree],
216228
gramian_accumulator: GramianAccumulator,
217229
module: nn.Module,
218230
*grad_outputs: Tensor,
219231
) -> None:
220232
# There is no non-batched dimension
221-
generalized_jacobians = vjp(grad_outputs, args)
233+
generalized_jacobians = vjp(grad_outputs, args, kwargs)
222234
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
223235
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
224236

225237
@staticmethod
226238
def vmap(
227239
_,
228-
in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]]
240+
in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]]
229241
vjp: VJP,
230242
args: tuple[PyTree, ...],
243+
kwargs: dict[str, PyTree],
231244
gramian_accumulator: GramianAccumulator,
232245
module: nn.Module,
233246
*jac_outputs: Tensor,
234247
) -> tuple[None, None]:
235248
# There is a non-batched dimension
236249
# We do not vmap over the args for the non-batched dimension
237-
in_dims = (in_dims[4:], tree_map(lambda _: None, args))
238-
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
250+
in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
251+
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs)
239252
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
240253
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
241254
return None, None

src/torchjd/autogram/_vjp.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class VJP(ABC):
2020

2121
@abstractmethod
2222
def __call__(
23-
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
23+
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree]
2424
) -> dict[str, Tensor]:
2525
"""
2626
Computes and returns the dictionary of parameter names to their gradients for the given
@@ -59,19 +59,23 @@ def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]):
5959
self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
6060

6161
def __call__(
62-
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
62+
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree]
6363
) -> dict[str, Tensor]:
64-
return self.vmapped_vjp(grad_outputs, args)
64+
return self.vmapped_vjp(grad_outputs, args, kwargs)
6565

6666
def _call_on_one_instance(
67-
self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...]
67+
self,
68+
grad_outputs_j: tuple[Tensor, ...],
69+
args_j: tuple[PyTree, ...],
70+
kwargs_j: dict[str, PyTree],
6871
) -> dict[str, Tensor]:
6972
# Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
7073
# "batch" of 1 activation (or grad_output). This is because some layers (e.g.
7174
# nn.Flatten) do not work equivalently if they're provided with a batch or with
7275
# an element of a batch. We thus always provide them with batches, just of a
7376
# different size.
7477
args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
78+
kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
7579
grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]
7680

7781
def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
@@ -80,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor
8084
**dict(self.module.named_buffers()),
8185
**self.frozen_params,
8286
}
83-
output = torch.func.functional_call(self.module, all_state, args_j)
87+
output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
8488
flat_outputs = tree_flatten(output)[0]
8589
rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
8690
return rg_outputs
@@ -108,7 +112,7 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
108112
self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)
109113

110114
def __call__(
111-
self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...]
115+
self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree]
112116
) -> dict[str, Tensor]:
113117
grads = torch.autograd.grad(
114118
self.rg_outputs,

tests/unit/autogram/test_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
WithDropout,
5353
WithModuleTrackingRunningStats,
5454
WithModuleWithHybridPyTreeArg,
55+
WithModuleWithHybridPyTreeKwarg,
5556
WithModuleWithStringArg,
57+
WithModuleWithStringKwarg,
5658
WithModuleWithStringOutput,
5759
WithNoTensorOutput,
5860
WithRNN,
@@ -113,6 +115,8 @@
113115
(WithModuleWithStringArg, 32),
114116
(WithModuleWithHybridPyTreeArg, 32),
115117
(WithModuleWithStringOutput, 32),
118+
(WithModuleWithStringKwarg, 32),
119+
(WithModuleWithHybridPyTreeKwarg, 32),
116120
(FreeParam, 32),
117121
(NoFreeParam, 32),
118122
param(Cifar10Model, 16, marks=mark.slow),

tests/utils/architectures.py

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -772,29 +772,66 @@ def forward(self, input: Tensor) -> Tensor:
772772
return input @ self.linear.weight.T + self.linear.bias
773773

774774

775+
class _WithStringArg(nn.Module):
776+
def __init__(self):
777+
super().__init__()
778+
self.matrix = nn.Parameter(torch.randn(2, 3))
779+
780+
def forward(self, s: str, input: Tensor) -> Tensor:
781+
if s == "two":
782+
return input @ self.matrix * 2.0
783+
else:
784+
return input @ self.matrix
785+
786+
775787
class WithModuleWithStringArg(ShapedModule):
776788
"""Model containing a module that has a string argument."""
777789

778790
INPUT_SHAPES = (2,)
779791
OUTPUT_SHAPES = (3,)
780792

781-
class WithStringArg(nn.Module):
782-
def __init__(self):
783-
super().__init__()
784-
self.matrix = nn.Parameter(torch.randn(2, 3))
793+
def __init__(self):
794+
super().__init__()
795+
self.with_string_arg = _WithStringArg()
785796

786-
def forward(self, s: str, input: Tensor) -> Tensor:
787-
if s == "two":
788-
return input @ self.matrix * 2.0
789-
else:
790-
return input @ self.matrix
797+
def forward(self, input: Tensor) -> Tensor:
798+
return self.with_string_arg("two", input)
799+
800+
801+
class WithModuleWithStringKwarg(ShapedModule):
802+
"""Model calling its submodule's forward with a string and a tensor as keyword arguments."""
803+
804+
INPUT_SHAPES = (2,)
805+
OUTPUT_SHAPES = (3,)
791806

792807
def __init__(self):
793808
super().__init__()
794-
self.with_string_arg = self.WithStringArg()
809+
self.with_string_arg = _WithStringArg()
795810

796811
def forward(self, input: Tensor) -> Tensor:
797-
return self.with_string_arg("two", input)
812+
return self.with_string_arg(s="two", input=input)
813+
814+
815+
class _WithHybridPyTreeArg(nn.Module):
816+
def __init__(self):
817+
super().__init__()
818+
self.m0 = nn.Parameter(torch.randn(3, 3))
819+
self.m1 = nn.Parameter(torch.randn(4, 3))
820+
self.m2 = nn.Parameter(torch.randn(5, 3))
821+
self.m3 = nn.Parameter(torch.randn(6, 3))
822+
823+
def forward(self, input: PyTree) -> Tensor:
824+
t0 = input["one"][0][0]
825+
t1 = input["one"][0][1]
826+
t2 = input["one"][1]
827+
t3 = input["two"]
828+
829+
c0 = input["one"][0][3]
830+
c1 = input["one"][0][4][0]
831+
c2 = input["one"][2]
832+
c3 = input["three"]
833+
834+
return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3
798835

799836

800837
class WithModuleWithHybridPyTreeArg(ShapedModule):
@@ -806,31 +843,39 @@ class WithModuleWithHybridPyTreeArg(ShapedModule):
806843
INPUT_SHAPES = (10,)
807844
OUTPUT_SHAPES = (3,)
808845

809-
class WithHybridPyTreeArg(nn.Module):
810-
def __init__(self):
811-
super().__init__()
812-
self.m0 = nn.Parameter(torch.randn(3, 3))
813-
self.m1 = nn.Parameter(torch.randn(4, 3))
814-
self.m2 = nn.Parameter(torch.randn(5, 3))
815-
self.m3 = nn.Parameter(torch.randn(6, 3))
846+
def __init__(self):
847+
super().__init__()
848+
self.linear = nn.Linear(10, 18)
849+
self.with_string_arg = _WithHybridPyTreeArg()
816850

817-
def forward(self, input: PyTree) -> Tensor:
818-
t0 = input["one"][0][0]
819-
t1 = input["one"][0][1]
820-
t2 = input["one"][1]
821-
t3 = input["two"]
851+
def forward(self, input: Tensor) -> Tensor:
852+
input = self.linear(input)
822853

823-
c0 = input["one"][0][3]
824-
c1 = input["one"][0][4][0]
825-
c2 = input["one"][2]
826-
c3 = input["three"]
854+
t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18]
827855

828-
return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3
856+
tree = {
857+
"zero": "unused",
858+
"one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"],
859+
"two": t3,
860+
"three": 0.5,
861+
}
862+
863+
return self.with_string_arg(tree)
864+
865+
866+
class WithModuleWithHybridPyTreeKwarg(ShapedModule):
867+
"""
868+
Model calling its submodule's forward with a PyTree keyword argument containing a mix of tensors
869+
and non-tensor values.
870+
"""
871+
872+
INPUT_SHAPES = (10,)
873+
OUTPUT_SHAPES = (3,)
829874

830875
def __init__(self):
831876
super().__init__()
832877
self.linear = nn.Linear(10, 18)
833-
self.with_string_arg = self.WithHybridPyTreeArg()
878+
self.with_string_arg = _WithHybridPyTreeArg()
834879

835880
def forward(self, input: Tensor) -> Tensor:
836881
input = self.linear(input)
@@ -844,7 +889,7 @@ def forward(self, input: Tensor) -> Tensor:
844889
"three": 0.5,
845890
}
846891

847-
return self.with_string_arg(tree)
892+
return self.with_string_arg(input=tree)
848893

849894

850895
class WithModuleWithStringOutput(ShapedModule):
@@ -1067,11 +1112,11 @@ def forward(self, input: Tensor) -> Tensor:
10671112

10681113

10691114
# Other torchvision.models were not added for the following reasons:
1070-
# - VGG16: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
1115+
# - VGG16: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok.
10711116
# - DenseNet: no way to easily replace the BatchNorms (no norm_layer param)
10721117
# - InceptionV3: no way to easily replace the BatchNorms (no norm_layer param)
10731118
# - GoogleNet: no way to easily replace the BatchNorms (no norm_layer param)
10741119
# - ShuffleNetV2: no way to easily replace the BatchNorms (no norm_layer param)
1075-
# - ResNeXt: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
1076-
# - WideResNet50: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
1120+
# - ResNeXt: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok.
1121+
# - WideResNet50: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok.
10771122
# - MNASNet: no way to easily replace the BatchNorms (no norm_layer param)

0 commit comments

Comments
 (0)