
Commit 135d29f

Allow force_recompute_layer to be an int, for more fine-grained checkpointing control. (#889)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 68ea9eb commit 135d29f

1 file changed

Lines changed: 6 additions & 1 deletion


src/tabpfn/architectures/interface.py

@@ -75,13 +75,18 @@ class PerformanceOptions:
     throughput by reducing memory pressure and the number of CPU<->GPU
     synchronisation points required for memory allocations."""
 
-    force_recompute_layer: bool = False
+    force_recompute_layer: bool | int = False
     """Enable activation checkpointing (gradient recomputation) for all layers.
 
     When ``True``, intermediate activations are not stored during the forward pass;
     instead they are recomputed from scratch during the backward pass. This trades
     compute for memory and is useful when training with very large context sizes.
     Has no effect during inference (``torch.no_grad`` / ``torch.inference_mode``).
+
+    Some models support passing an integer value, where 0 corresponds to no
+    checkpointing, and higher values correspond to more aggressive checkpointing.
+    This allows for finer tuning of the compute/memory tradeoff. Models will clip the
+    value to their maximum supported level of checkpointing.
     """
 
     use_chunkwise_inference: bool = False
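
As a rough illustration of the semantics added in this commit, the sketch below shows how a model might normalise the ``bool | int`` value into a concrete checkpointing level. This is a minimal, hypothetical example: the ``resolve_checkpoint_level`` helper and the ``max_level`` parameter are illustrative names chosen here, not part of the tabpfn API.

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class PerformanceOptions:
    # Mirrors only the field changed in this commit; other fields omitted.
    force_recompute_layer: bool | int = False


def resolve_checkpoint_level(value: bool | int, max_level: int) -> int:
    """Map the option to a concrete checkpointing level (illustrative helper).

    False / 0 -> no checkpointing, True -> the model's maximum level,
    other ints are clipped to the range [0, max_level].
    """
    # bool is a subclass of int, so handle it first.
    if isinstance(value, bool):
        return max_level if value else 0
    return max(0, min(int(value), max_level))


if __name__ == "__main__":
    opts = PerformanceOptions(force_recompute_layer=3)
    # A hypothetical model that supports at most 2 checkpointing levels
    # clips the requested value of 3 down to 2.
    print(resolve_checkpoint_level(opts.force_recompute_layer, max_level=2))  # 2
    print(resolve_checkpoint_level(True, max_level=2))   # 2
    print(resolve_checkpoint_level(False, max_level=2))  # 0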
