@@ -75,13 +75,18 @@ class PerformanceOptions:
7575 throughput by reducing memory pressure and the number of CPU<->GPU
7676 synchronisation points required for memory allocations."""
7777
78- force_recompute_layer : bool = False
78+ force_recompute_layer : bool | int = False
7979 """Enable activation checkpointing (gradient recomputation) for all layers.
8080
8181 When ``True``, intermediate activations are not stored during the forward pass;
8282 instead they are recomputed from scratch during the backward pass. This trades
8383 compute for memory and is useful when training with very large context sizes.
8484 Has no effect during inference (``torch.no_grad`` / ``torch.inference_mode``).
85+
86+ Some models also accept an integer, where ``0`` means no checkpointing and
87+ higher values enable progressively more aggressive checkpointing, giving
88+ finer-grained control over the compute/memory trade-off. Models clip the
89+ value to their maximum supported checkpointing level.
8590 """
8691
8792 use_chunkwise_inference : bool = False