Skip to content

Commit 481d83f

Browse files
authored
feat(autogram): Remove batched optimizations (#470)
* Remove batch_dim parameter from Engine * Remove test_batched_non_batched_equivalence and test_batched_non_batched_equivalence_2 * Adapt all usages of Engine to not provide batch_dim * Remove FunctionalJacobianComputer * Remove args and kwargs from interface of JacobianComputer, GramianComputer and JacobianAccumulator because they were only needed for the functional interface * Remove kwargs from interface of Hook and stop registering it with with_kwargs=True (args are mandatory though, so rename them as _). * Change JacobianComputer to compute generalized jacobians (shape [m0, ..., mk, n]) and change GramianComputer to compute optional generalized gramians (shape [m0, ..., mk, mk, ..., m0]) * Change engine.compute_gramian to always simply do one vmap level per dimension of the output, without caring about the batch_dim. * Remove all reshapes and movedims in engine.compute_gramian: we don't need reshape anymore since the gramian is directly a generalized gramian, and we don't need movedim anymore since we vmap on all dimensions the same way, without having to put the non-batched dim in front. Merge compute_gramian and _compute_square_gramian. * Add temporary function to create the initial jac_output (dense). This should be updated to a tensor format that is optimized for batched computations.
1 parent 878bf9b commit 481d83f

File tree

11 files changed

+125
-500
lines changed

11 files changed

+125
-500
lines changed

docs/source/examples/iwmtl.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ The following example shows how to do that.
3131
optimizer = SGD(params, lr=0.1)
3232
mse = MSELoss(reduction="none")
3333
weighting = Flattening(UPGradWeighting())
34-
engine = Engine(shared_module, batch_dim=0)
34+
engine = Engine(shared_module)
3535
3636
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
3737
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task

docs/source/examples/iwrm.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
129129
params = model.parameters()
130130
optimizer = SGD(params, lr=0.1)
131131
weighting = UPGradWeighting()
132-
engine = Engine(model, batch_dim=0)
132+
engine = Engine(model)
133133
134134
for x, y in zip(X, Y):
135135
y_hat = model(x).squeeze(dim=1) # shape: [16]

docs/source/examples/partial_jd.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
3333
3434
# Create the autogram engine that will compute the Gramian of the
3535
# Jacobian with respect to the two last Linear layers' parameters.
36-
engine = Engine(model[2:], batch_dim=0)
36+
engine = Engine(model[2:])
3737
3838
params = model.parameters()
3939
optimizer = SGD(params, lr=0.1)

src/torchjd/autogram/_engine.py

Lines changed: 41 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -7,41 +7,9 @@
77
from ._edge_registry import EdgeRegistry
88
from ._gramian_accumulator import GramianAccumulator
99
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
10-
from ._gramian_utils import movedim_gramian, reshape_gramian
11-
from ._jacobian_computer import (
12-
AutogradJacobianComputer,
13-
FunctionalJacobianComputer,
14-
JacobianComputer,
15-
)
10+
from ._jacobian_computer import AutogradJacobianComputer
1611
from ._module_hook_manager import ModuleHookManager
1712

18-
_MODULES_INCOMPATIBLE_WITH_BATCHED = (
19-
nn.BatchNorm1d,
20-
nn.BatchNorm2d,
21-
nn.BatchNorm3d,
22-
nn.LazyBatchNorm1d,
23-
nn.LazyBatchNorm2d,
24-
nn.LazyBatchNorm3d,
25-
nn.SyncBatchNorm,
26-
nn.RNNBase,
27-
)
28-
29-
_TRACK_RUNNING_STATS_MODULE_TYPES = (
30-
nn.BatchNorm1d,
31-
nn.BatchNorm2d,
32-
nn.BatchNorm3d,
33-
nn.LazyBatchNorm1d,
34-
nn.LazyBatchNorm2d,
35-
nn.LazyBatchNorm3d,
36-
nn.SyncBatchNorm,
37-
nn.InstanceNorm1d,
38-
nn.InstanceNorm2d,
39-
nn.InstanceNorm3d,
40-
nn.LazyInstanceNorm1d,
41-
nn.LazyInstanceNorm2d,
42-
nn.LazyInstanceNorm3d,
43-
)
44-
4513

4614
class Engine:
4715
"""
@@ -50,7 +18,7 @@ class Engine:
5018
Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_ but goes even further:
5119
5220
* It works for any computation graph (not just sequential models).
53-
* It is optimized for batched computations (as long as ``batch_dim`` is specified).
21+
* It is highly optimized for batched computations but also supports non-batched computations.
5422
* It supports any shape of tensor to differentiate (not just a vector of losses). For more
5523
details about this, look at :meth:`Engine.compute_gramian`.
5624
@@ -66,10 +34,6 @@ class Engine:
6634
:param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
6735
Several modules can be provided, but it's important that none of them is a child module of
6836
another of them.
69-
:param batch_dim: If the modules work with batches and process each batch element independently,
70-
then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
71-
memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates
72-
the batch dimension of the output tensor, if any.
7337
7438
.. admonition::
7539
Example
@@ -97,7 +61,7 @@ class Engine:
9761
weighting = UPGradWeighting()
9862
9963
# Create the engine before the backward pass, and only once.
100-
engine = Engine(model, batch_dim=0)
64+
engine = Engine(model)
10165
10266
for input, target in zip(inputs, targets):
10367
output = model(input).squeeze(dim=1) # shape: [16]
@@ -113,48 +77,13 @@ class Engine:
11377
since the Jacobian never has to be entirely in memory, it is often much more
11478
memory-efficient, and thus typically faster, to use the Gramian-based approach.
11579
80+
.. warning:: For autogram to be fast and low-memory, it is very important to use only batched
81+
modules (i.e. modules that treat each element of the batch independently). For instance,
82+
BatchNorm is not a batched module because it computes some statistics over the batch.
83+
11684
.. warning::
117-
When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:
118-
119-
* They should treat the elements of the batch independently. Most common layers respect
120-
this, but for example `BatchNorm
121-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
122-
computes some average and standard deviation over the elements of the batch).
123-
* Their inputs and outputs can be anything, but each input tensor and each output tensor
124-
must be batched on its first dimension. When available (e.g. in `Transformers
125-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
126-
`MultiheadAttention
127-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
128-
etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this makes `RNNs
129-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ not supported yet
130-
because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``.
131-
* They should not perform in-place operations on tensors (for instance you should not use
132-
``track_running_stats=True`` in normalization layers).
133-
* They should not have side effects during the forward pass (since their forward pass will
134-
be called twice, the side effects could be different from what's expected).
135-
* If they have some randomness during the forward pass, they should not have direct
136-
trainable parameters. For this reason,
137-
`Transformers
138-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
139-
dropout function (rather than a `Dropout
140-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
141-
module with some trainable parameters, has to be used with
142-
``dropout=0.0``. Note that a `Dropout
143-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers are
144-
entirely supported and should be preferred. It is also perfectly fine for random modules
145-
to have child modules that have trainable parameters, so if you have a random module with
146-
some direct parameters, a simple fix is to wrap these parameters into a child module.
147-
148-
If you're building your own architecture, respecting those criteria should be quite easy.
149-
However, if you're using an existing architecture, you may have to modify it to make it
150-
compatible with the autogram engine. For instance, you may want to replace `BatchNorm2d
151-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ layers by
152-
`GroupNorm <https://docs.pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html>`_ or
153-
`InstanceNorm2d
154-
<https://docs.pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html>`_ layers.
155-
156-
The alternative is to use ``batch_dim=None``, but it's not recommended since it will
157-
increase memory usage by a lot and thus typically slow down computation.
85+
`RNNs <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ may not be
86+
supported on CUDA because vmap is not implemented for RNN on that device.
15887
15988
.. warning::
16089
Parent modules should call their child modules directly rather than using their child
@@ -177,14 +106,9 @@ class Engine:
177106
another child module to avoid the slow-down.
178107
"""
179108

180-
def __init__(
181-
self,
182-
*modules: nn.Module,
183-
batch_dim: int | None,
184-
):
109+
def __init__(self, *modules: nn.Module):
185110
self._gramian_accumulator = GramianAccumulator()
186111
self._target_edges = EdgeRegistry()
187-
self._batch_dim = batch_dim
188112
self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator)
189113
self._gramian_computers = dict[nn.Module, GramianComputer]()
190114

@@ -193,7 +117,6 @@ def __init__(
193117

194118
def _hook_module_recursively(self, module: nn.Module) -> None:
195119
if any(p.requires_grad for p in module.parameters(recurse=False)):
196-
self._check_module_is_compatible(module)
197120
gramian_computer = self._make_gramian_computer(module)
198121
self._gramian_computers[module] = gramian_computer
199122
self._module_hook_manager.hook_module(module, gramian_computer)
@@ -202,36 +125,11 @@ def _hook_module_recursively(self, module: nn.Module) -> None:
202125
self._hook_module_recursively(child)
203126

204127
def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
205-
jacobian_computer: JacobianComputer
206-
if self._batch_dim is not None:
207-
jacobian_computer = FunctionalJacobianComputer(module)
208-
else:
209-
jacobian_computer = AutogradJacobianComputer(module)
128+
jacobian_computer = AutogradJacobianComputer(module)
210129
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
211130

212131
return gramian_computer
213132

214-
def _check_module_is_compatible(self, module: nn.Module) -> None:
215-
if self._batch_dim is not None:
216-
if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED):
217-
raise ValueError(
218-
f"Found a module of type {type(module)}, which is incompatible with the "
219-
f"autogram engine when `batch_dim` is not `None`. The incompatible module types"
220-
f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The "
221-
f"recommended fix is to replace incompatible layers by something else (e.g. "
222-
f"BatchNorm by InstanceNorm). If you really can't and performance is not a "
223-
f"priority, you may also just set `batch_dim=None` when creating the engine."
224-
)
225-
if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats:
226-
raise ValueError(
227-
f"Found a module of type {type(module)}, with `track_running_stats=True`, which"
228-
f" is incompatible with the autogram engine when `batch_dim` is not `None`, due"
229-
f" to performing in-place operations on tensors and having side-effects during "
230-
f"the forward pass. Try setting `track_running_stats` to `False`. If you really"
231-
f" can't and performance is not a priority, you may also just set "
232-
f"`batch_dim=None` when creating the engine."
233-
)
234-
235133
def compute_gramian(self, output: Tensor) -> Tensor:
236134
r"""
237135
Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
@@ -261,33 +159,31 @@ def compute_gramian(self, output: Tensor) -> Tensor:
261159
- etc.
262160
"""
263161

264-
if self._batch_dim is not None:
265-
# move batched dim to the end
266-
ordered_output = output.movedim(self._batch_dim, -1)
267-
ordered_shape = list(ordered_output.shape)
268-
batch_size = ordered_shape[-1]
269-
has_non_batch_dim = len(ordered_shape) > 1
270-
target_shape = [batch_size]
271-
else:
272-
ordered_output = output
273-
ordered_shape = list(ordered_output.shape)
274-
has_non_batch_dim = len(ordered_shape) > 0
275-
target_shape = []
162+
self._module_hook_manager.gramian_accumulation_phase.value = True
163+
164+
try:
165+
leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
166+
167+
def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
168+
return torch.autograd.grad(
169+
outputs=output,
170+
inputs=leaf_targets,
171+
grad_outputs=_grad_output,
172+
retain_graph=True,
173+
)
276174

277-
if has_non_batch_dim:
278-
target_shape = [-1] + target_shape
175+
output_dims = list(range(output.ndim))
176+
jac_output = _make_initial_jac_output(output)
279177

280-
reshaped_output = ordered_output.reshape(target_shape)
281-
# There are four different cases for the shape of reshaped_output:
282-
# - Not batched and not non-batched: scalar of shape []
283-
# - Batched only: vector of shape [batch_size]
284-
# - Non-batched only: vector of shape [dim]
285-
# - Batched and non-batched: matrix of shape [dim, batch_size]
178+
vmapped_diff = differentiation
179+
for _ in output_dims:
180+
vmapped_diff = vmap(vmapped_diff)
286181

287-
self._module_hook_manager.gramian_accumulation_phase.value = True
182+
_ = vmapped_diff(jac_output)
288183

289-
try:
290-
square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim)
184+
# If the gramian were None, then leaf_targets would be empty, so autograd.grad would
185+
# have failed. So gramian is necessarily a valid Tensor here.
186+
gramian = cast(Tensor, self._gramian_accumulator.gramian)
291187
finally:
292188
# Reset everything that has a state, even if the previous call raised an exception
293189
self._module_hook_manager.gramian_accumulation_phase.value = False
@@ -296,40 +192,16 @@ def compute_gramian(self, output: Tensor) -> Tensor:
296192
for gramian_computer in self._gramian_computers.values():
297193
gramian_computer.reset()
298194

299-
unordered_gramian = reshape_gramian(square_gramian, ordered_shape)
300-
301-
if self._batch_dim is not None:
302-
gramian = movedim_gramian(unordered_gramian, [-1], [self._batch_dim])
303-
else:
304-
gramian = unordered_gramian
305-
306195
return gramian
307196

308-
def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor:
309-
leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
310-
311-
def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
312-
return torch.autograd.grad(
313-
outputs=output,
314-
inputs=leaf_targets,
315-
grad_outputs=_grad_output,
316-
retain_graph=True,
317-
)
318-
319-
if has_non_batch_dim:
320-
# There is one non-batched dimension, it is the first one
321-
non_batch_dim_len = output.shape[0]
322-
identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
323-
ones = torch.ones_like(output[0])
324-
jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
325-
326-
_ = vmap(differentiation)(jac_output)
327-
else:
328-
grad_output = torch.ones_like(output)
329-
_ = differentiation(grad_output)
330197

331-
# If the gramian were None, then leaf_targets would be empty, so autograd.grad would
332-
# have failed. So gramian is necessarily a valid Tensor here.
333-
gramian = cast(Tensor, self._gramian_accumulator.gramian)
198+
def _make_initial_jac_output(output: Tensor) -> Tensor:
199+
if output.ndim == 0:
200+
return torch.ones_like(output)
201+
p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape]
202+
p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
203+
v_indices_grid = p_indices_grid + p_indices_grid
334204

335-
return gramian
205+
res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype)
206+
res[v_indices_grid] = 1.0
207+
return res

src/torchjd/autogram/_gramian_computer.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from abc import ABC, abstractmethod
22
from typing import Optional
33

4+
import torch
45
from torch import Tensor
5-
from torch.utils._pytree import PyTree
66

77
from torchjd.autogram._jacobian_computer import JacobianComputer
88

@@ -13,8 +13,6 @@ def __call__(
1313
self,
1414
rg_outputs: tuple[Tensor, ...],
1515
grad_outputs: tuple[Tensor, ...],
16-
args: tuple[PyTree, ...],
17-
kwargs: dict[str, PyTree],
1816
) -> Optional[Tensor]:
1917
"""Compute what we can for a module and optionally return the gramian if it's ready."""
2018

@@ -30,8 +28,12 @@ def __init__(self, jacobian_computer):
3028
self.jacobian_computer = jacobian_computer
3129

3230
@staticmethod
33-
def _to_gramian(jacobian: Tensor) -> Tensor:
34-
return jacobian @ jacobian.T
31+
def _to_gramian(matrix: Tensor) -> Tensor:
32+
"""Contracts the last dimension of matrix to make it into a Gramian."""
33+
34+
indices = list(range(matrix.ndim))
35+
transposed_matrix = matrix.movedim(indices, indices[::-1])
36+
return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
3537

3638

3739
class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
@@ -53,20 +55,16 @@ def track_forward_call(self) -> None:
5355
self.remaining_counter += 1
5456

5557
def __call__(
56-
self,
57-
rg_outputs: tuple[Tensor, ...],
58-
grad_outputs: tuple[Tensor, ...],
59-
args: tuple[PyTree, ...],
60-
kwargs: dict[str, PyTree],
58+
self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]
6159
) -> Optional[Tensor]:
6260
"""Compute what we can for a module and optionally return the gramian if it's ready."""
6361

64-
jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
62+
jacobian = self.jacobian_computer(rg_outputs, grad_outputs)
6563

6664
if self.summed_jacobian is None:
67-
self.summed_jacobian = jacobian_matrix
65+
self.summed_jacobian = jacobian
6866
else:
69-
self.summed_jacobian += jacobian_matrix
67+
self.summed_jacobian += jacobian
7068

7169
self.remaining_counter -= 1
7270

0 commit comments

Comments
 (0)