77from ._edge_registry import EdgeRegistry
88from ._gramian_accumulator import GramianAccumulator
99from ._gramian_computer import GramianComputer , JacobianBasedGramianComputerWithCrossTerms
10- from ._gramian_utils import movedim_gramian , reshape_gramian
11- from ._jacobian_computer import (
12- AutogradJacobianComputer ,
13- FunctionalJacobianComputer ,
14- JacobianComputer ,
15- )
10+ from ._jacobian_computer import AutogradJacobianComputer
1611from ._module_hook_manager import ModuleHookManager
1712
18- _MODULES_INCOMPATIBLE_WITH_BATCHED = (
19- nn .BatchNorm1d ,
20- nn .BatchNorm2d ,
21- nn .BatchNorm3d ,
22- nn .LazyBatchNorm1d ,
23- nn .LazyBatchNorm2d ,
24- nn .LazyBatchNorm3d ,
25- nn .SyncBatchNorm ,
26- nn .RNNBase ,
27- )
28-
29- _TRACK_RUNNING_STATS_MODULE_TYPES = (
30- nn .BatchNorm1d ,
31- nn .BatchNorm2d ,
32- nn .BatchNorm3d ,
33- nn .LazyBatchNorm1d ,
34- nn .LazyBatchNorm2d ,
35- nn .LazyBatchNorm3d ,
36- nn .SyncBatchNorm ,
37- nn .InstanceNorm1d ,
38- nn .InstanceNorm2d ,
39- nn .InstanceNorm3d ,
40- nn .LazyInstanceNorm1d ,
41- nn .LazyInstanceNorm2d ,
42- nn .LazyInstanceNorm3d ,
43- )
44-
4513
4614class Engine :
4715 """
@@ -50,7 +18,7 @@ class Engine:
5018 Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_ but goes even further:
5119
5220 * It works for any computation graph (not just sequential models).
53- * It is optimized for batched computations (as long as ``batch_dim`` is specified) .
21+ * It is highly optimized for batched computations but also supports non-batched computations .
5422 * It supports any shape of tensor to differentiate (not just a vector of losses). For more
5523 details about this, look at :meth:`Engine.compute_gramian`.
5624
@@ -66,10 +34,6 @@ class Engine:
6634 :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
6735 Several modules can be provided, but it's important that none of them is a child module of
6836 another of them.
69- :param batch_dim: If the modules work with batches and process each batch element independently,
70- then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
71- memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates
72- the batch dimension of the output tensor, if any.
7337
7438 .. admonition::
7539 Example
@@ -97,7 +61,7 @@ class Engine:
9761 weighting = UPGradWeighting()
9862
9963 # Create the engine before the backward pass, and only once.
100- engine = Engine(model, batch_dim=0 )
64+ engine = Engine(model)
10165
10266 for input, target in zip(inputs, targets):
10367 output = model(input).squeeze(dim=1) # shape: [16]
@@ -113,48 +77,13 @@ class Engine:
11377 since the Jacobian never has to be entirely in memory, it is often much more
11478 memory-efficient, and thus typically faster, to use the Gramian-based approach.
11579
80+ .. warning:: For autogram to be fast and low-memory, it is very important to use only batched
81+ modules (i.e. modules that treat each element of the batch independently). For instance,
82+ BatchNorm is not a batched module because it computes some statistics over the batch.
83+
11684 .. warning::
117- When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:
118-
119- * They should treat the elements of the batch independently. Most common layers respect
120- this, but for example `BatchNorm
121- <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
122- computes some average and standard deviation over the elements of the batch).
123- * Their inputs and outputs can be anything, but each input tensor and each output tensor
124- must be batched on its first dimension. When available (e.g. in `Transformers
125- <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
126- `MultiheadAttention
127- <https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
128- etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this makes `RNNs
129- <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ not supported yet
130- because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``.
131- * They should not perform in-place operations on tensors (for instance you should not use
132- ``track_running_stats=True`` in normalization layers).
133- * They should not have side effects during the forward pass (since their forward pass will
134- be called twice, the side effects could be different from what's expected).
135- * If they have some randomness during the forward pass, they should not have direct
136- trainable parameters. For this reason,
137- `Transformers
138- <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
139- dropout function (rather than a `Dropout
140- <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
141- module with some trainable parameters, has to be used with
142- ``dropout=0.0``. Note that a `Dropout
143- <https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers are
144- entirely supported and should be preferred. It is also perfectly fine for random modules
145- to have child modules that have trainable parameters, so if you have a random module with
146- some direct parameters, a simple fix is to wrap these parameters into a child module.
147-
148- If you're building your own architecture, respecting those criteria should be quite easy.
149- However, if you're using an existing architecture, you may have to modify it to make it
150- compatible with the autogram engine. For instance, you may want to replace `BatchNorm2d
151- <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ layers by
152- `GroupNorm <https://docs.pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html>`_ or
153- `InstanceNorm2d
154- <https://docs.pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html>`_ layers.
155-
156- The alternative is to use ``batch_dim=None``, but it's not recommended since it will
157- increase memory usage by a lot and thus typically slow down computation.
85+ `RNNs <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ may not be
86+ supported on cuda because vmap is not implemented for RNN on that device.
15887
15988 .. warning::
16089 Parent modules should call their child modules directly rather than using their child
@@ -177,14 +106,9 @@ class Engine:
177106 another child module to avoid the slow-down.
178107 """
179108
180- def __init__ (
181- self ,
182- * modules : nn .Module ,
183- batch_dim : int | None ,
184- ):
109+ def __init__ (self , * modules : nn .Module ):
185110 self ._gramian_accumulator = GramianAccumulator ()
186111 self ._target_edges = EdgeRegistry ()
187- self ._batch_dim = batch_dim
188112 self ._module_hook_manager = ModuleHookManager (self ._target_edges , self ._gramian_accumulator )
189113 self ._gramian_computers = dict [nn .Module , GramianComputer ]()
190114
@@ -193,7 +117,6 @@ def __init__(
193117
194118 def _hook_module_recursively (self , module : nn .Module ) -> None :
195119 if any (p .requires_grad for p in module .parameters (recurse = False )):
196- self ._check_module_is_compatible (module )
197120 gramian_computer = self ._make_gramian_computer (module )
198121 self ._gramian_computers [module ] = gramian_computer
199122 self ._module_hook_manager .hook_module (module , gramian_computer )
@@ -202,36 +125,11 @@ def _hook_module_recursively(self, module: nn.Module) -> None:
202125 self ._hook_module_recursively (child )
203126
204127 def _make_gramian_computer (self , module : nn .Module ) -> GramianComputer :
205- jacobian_computer : JacobianComputer
206- if self ._batch_dim is not None :
207- jacobian_computer = FunctionalJacobianComputer (module )
208- else :
209- jacobian_computer = AutogradJacobianComputer (module )
128+ jacobian_computer = AutogradJacobianComputer (module )
210129 gramian_computer = JacobianBasedGramianComputerWithCrossTerms (jacobian_computer )
211130
212131 return gramian_computer
213132
214- def _check_module_is_compatible (self , module : nn .Module ) -> None :
215- if self ._batch_dim is not None :
216- if isinstance (module , _MODULES_INCOMPATIBLE_WITH_BATCHED ):
217- raise ValueError (
218- f"Found a module of type { type (module )} , which is incompatible with the "
219- f"autogram engine when `batch_dim` is not `None`. The incompatible module types"
220- f" are { _MODULES_INCOMPATIBLE_WITH_BATCHED } (and their subclasses). The "
221- f"recommended fix is to replace incompatible layers by something else (e.g. "
222- f"BatchNorm by InstanceNorm). If you really can't and performance is not a "
223- f"priority, you may also just set `batch_dim=None` when creating the engine."
224- )
225- if isinstance (module , _TRACK_RUNNING_STATS_MODULE_TYPES ) and module .track_running_stats :
226- raise ValueError (
227- f"Found a module of type { type (module )} , with `track_running_stats=True`, which"
228- f" is incompatible with the autogram engine when `batch_dim` is not `None`, due"
229- f" to performing in-place operations on tensors and having side-effects during "
230- f"the forward pass. Try setting `track_running_stats` to `False`. If you really"
231- f" can't and performance is not a priority, you may also just set "
232- f"`batch_dim=None` when creating the engine."
233- )
234-
235133 def compute_gramian (self , output : Tensor ) -> Tensor :
236134 r"""
237135 Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
@@ -261,33 +159,31 @@ def compute_gramian(self, output: Tensor) -> Tensor:
261159 - etc.
262160 """
263161
264- if self ._batch_dim is not None :
265- # move batched dim to the end
266- ordered_output = output . movedim ( self . _batch_dim , - 1 )
267- ordered_shape = list (ordered_output . shape )
268- batch_size = ordered_shape [ - 1 ]
269- has_non_batch_dim = len ( ordered_shape ) > 1
270- target_shape = [ batch_size ]
271- else :
272- ordered_output = output
273- ordered_shape = list ( ordered_output . shape )
274- has_non_batch_dim = len ( ordered_shape ) > 0
275- target_shape = []
162+ self ._module_hook_manager . gramian_accumulation_phase . value = True
163+
164+ try :
165+ leaf_targets = list (self . _target_edges . get_leaf_edges ({ get_gradient_edge ( output )}) )
166+
167+ def differentiation ( _grad_output : Tensor ) -> tuple [ Tensor , ...]:
168+ return torch . autograd . grad (
169+ outputs = output ,
170+ inputs = leaf_targets ,
171+ grad_outputs = _grad_output ,
172+ retain_graph = True ,
173+ )
276174
277- if has_non_batch_dim :
278- target_shape = [ - 1 ] + target_shape
175+ output_dims = list ( range ( output . ndim ))
176+ jac_output = _make_initial_jac_output ( output )
279177
280- reshaped_output = ordered_output .reshape (target_shape )
281- # There are four different cases for the shape of reshaped_output:
282- # - Not batched and not non-batched: scalar of shape []
283- # - Batched only: vector of shape [batch_size]
284- # - Non-batched only: vector of shape [dim]
285- # - Batched and non-batched: matrix of shape [dim, batch_size]
178+ vmapped_diff = differentiation
179+ for _ in output_dims :
180+ vmapped_diff = vmap (vmapped_diff )
286181
287- self . _module_hook_manager . gramian_accumulation_phase . value = True
182+ _ = vmapped_diff ( jac_output )
288183
289- try :
290- square_gramian = self ._compute_square_gramian (reshaped_output , has_non_batch_dim )
184+ # If the gramian were None, then leaf_targets would be empty, so autograd.grad would
185+ # have failed. So gramian is necessarily a valid Tensor here.
186+ gramian = cast (Tensor , self ._gramian_accumulator .gramian )
291187 finally :
292188 # Reset everything that has a state, even if the previous call raised an exception
293189 self ._module_hook_manager .gramian_accumulation_phase .value = False
@@ -296,40 +192,16 @@ def compute_gramian(self, output: Tensor) -> Tensor:
296192 for gramian_computer in self ._gramian_computers .values ():
297193 gramian_computer .reset ()
298194
299- unordered_gramian = reshape_gramian (square_gramian , ordered_shape )
300-
301- if self ._batch_dim is not None :
302- gramian = movedim_gramian (unordered_gramian , [- 1 ], [self ._batch_dim ])
303- else :
304- gramian = unordered_gramian
305-
306195 return gramian
307196
308- def _compute_square_gramian (self , output : Tensor , has_non_batch_dim : bool ) -> Tensor :
309- leaf_targets = list (self ._target_edges .get_leaf_edges ({get_gradient_edge (output )}))
310-
311- def differentiation (_grad_output : Tensor ) -> tuple [Tensor , ...]:
312- return torch .autograd .grad (
313- outputs = output ,
314- inputs = leaf_targets ,
315- grad_outputs = _grad_output ,
316- retain_graph = True ,
317- )
318-
319- if has_non_batch_dim :
320- # There is one non-batched dimension, it is the first one
321- non_batch_dim_len = output .shape [0 ]
322- identity_matrix = torch .eye (non_batch_dim_len , device = output .device , dtype = output .dtype )
323- ones = torch .ones_like (output [0 ])
324- jac_output = torch .einsum ("ij, ... -> ij..." , identity_matrix , ones )
325-
326- _ = vmap (differentiation )(jac_output )
327- else :
328- grad_output = torch .ones_like (output )
329- _ = differentiation (grad_output )
330197
331- # If the gramian were None, then leaf_targets would be empty, so autograd.grad would
332- # have failed. So gramian is necessarily a valid Tensor here.
333- gramian = cast (Tensor , self ._gramian_accumulator .gramian )
198+ def _make_initial_jac_output (output : Tensor ) -> Tensor :
199+ if output .ndim == 0 :
200+ return torch .ones_like (output )
201+ p_index_ranges = [torch .arange (s , device = output .device ) for s in output .shape ]
202+ p_indices_grid = torch .meshgrid (* p_index_ranges , indexing = "ij" )
203+ v_indices_grid = p_indices_grid + p_indices_grid
334204
335- return gramian
205+ res = torch .zeros (list (output .shape ) * 2 , device = output .device , dtype = output .dtype )
206+ res [v_indices_grid ] = 1.0
207+ return res