Skip to content

Commit 15760a5

Browse files
kainzhongksivaman
andauthored
[PyTorch] Add an API restore from function context to ensure tensors are detached (NVIDIA#2772)
[PyTorch] Change the restore tensor API to ensure tensors are detached from ctx Signed-off-by: Kaining Zhong <kainingz@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 3e61687 commit 15760a5

11 files changed

Lines changed: 41 additions & 36 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from transformer_engine.pytorch.quantized_tensor import (
5050
Quantizer,
5151
prepare_for_saving,
52-
restore_from_saved,
52+
restore_from_func_ctx,
5353
)
5454

5555
_current_file = pathlib.Path(__file__).resolve()
@@ -2701,10 +2701,7 @@ def forward(
27012701
@staticmethod
27022702
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
27032703
with torch.cuda.nvtx.range("_DPA"):
2704-
saved_tensors = ctx.saved_tensors
2705-
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
2706-
ctx.tensor_objects, saved_tensors
2707-
)
2704+
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx)
27082705

27092706
proj_dgrad = ctx.dO_quantizer(grad_output)
27102707
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

transformer_engine/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from transformer_engine.pytorch.quantized_tensor import Quantizer
6969
from transformer_engine.pytorch.quantized_tensor import prepare_for_saving
7070
from transformer_engine.pytorch.quantized_tensor import restore_from_saved
71+
from transformer_engine.pytorch.quantized_tensor import restore_from_func_ctx
7172
from transformer_engine.pytorch.tensor import Float8Quantizer
7273
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
7374
from transformer_engine.pytorch.tensor import MXFP8Quantizer

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from transformer_engine.pytorch.quantized_tensor import (
3333
QuantizedTensorStorage,
3434
prepare_for_saving,
35-
restore_from_saved,
35+
restore_from_func_ctx,
3636
)
3737
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
3838
from transformer_engine.pytorch.constants import (
@@ -1477,7 +1477,7 @@ def backward(ctx, d_out, *_args):
14771477
cu_seqlens_q_padded,
14781478
cu_seqlens_kv_padded,
14791479
*other_tensors,
1480-
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
1480+
) = restore_from_func_ctx(ctx)
14811481

14821482
aux_ctx_tensors = other_tensors
14831483

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
from transformer_engine.pytorch.quantized_tensor import (
4040
prepare_for_saving,
41-
restore_from_saved,
41+
restore_from_func_ctx,
4242
)
4343

4444
# Import attention utils
@@ -2085,7 +2085,7 @@ def backward(ctx, dout, *_args):
20852085
cu_seqlens_q_padded,
20862086
cu_seqlens_kv_padded,
20872087
*other_tensors,
2088-
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
2088+
) = restore_from_func_ctx(ctx)
20892089
cu_seqlens_q_per_step = other_tensors[:cp_size]
20902090
cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
20912091
rng_states = other_tensors[cp_size * 2 : cp_size * 3]
@@ -3675,7 +3675,7 @@ def backward(ctx, dout, *_args):
36753675
cu_seqlens_q_padded,
36763676
cu_seqlens_kv_padded,
36773677
*aux_ctx_tensors,
3678-
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
3678+
) = restore_from_func_ctx(ctx)
36793679

36803680
qkv_format = ctx.qkv_format
36813681
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
QuantizedTensorStorage,
5050
Quantizer,
5151
prepare_for_saving,
52-
restore_from_saved,
52+
restore_from_func_ctx,
5353
)
5454
from ...debug.pytorch.debug_quantization import DebugQuantizer
5555
from ...debug.pytorch.debug_state import TEDebugState
@@ -316,7 +316,7 @@ def forward(
316316
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
317317
# pylint: disable=missing-function-docstring
318318
with get_nvtx_range_context("_GroupedLinear_backward"):
319-
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
319+
saved_tensors = restore_from_func_ctx(ctx)
320320
N = ctx.num_gemms
321321
inputmats = saved_tensors[:N]
322322
weights = saved_tensors[N : 2 * N]

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
QuantizedTensorStorage,
6161
Quantizer,
6262
prepare_for_saving,
63-
restore_from_saved,
63+
restore_from_func_ctx,
6464
)
6565
from ...debug.pytorch.debug_state import TEDebugState
6666
from ..tensor.mxfp8_tensor import MXFP8Quantizer
@@ -546,7 +546,6 @@ def backward(
546546
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
547547

548548
with get_nvtx_range_context("_LayerNormLinear_backward"):
549-
saved_tensors = ctx.saved_tensors
550549
( # pylint: disable=unbalanced-tuple-unpacking
551550
inputmat,
552551
weight,
@@ -556,11 +555,7 @@ def backward(
556555
ln_out,
557556
mu,
558557
rsigma,
559-
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
560-
561-
# Delete the references to tensor objects once they've been consumed
562-
# by the `restore_from_saved` method to construct back the actual tensors.
563-
ctx.tensor_objects = None
558+
) = restore_from_func_ctx(ctx)
564559

565560
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
566561
main_grad = (

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
QuantizedTensorStorage,
8181
Quantizer,
8282
prepare_for_saving,
83-
restore_from_saved,
83+
restore_from_func_ctx,
8484
)
8585
from ..cpp_extensions import (
8686
general_gemm,
@@ -898,11 +898,7 @@ def forward(
898898
def _recompute(ctx):
899899
# pylint: disable=missing-function-docstring
900900

901-
saved_tensors = ctx.saved_tensors
902-
tensors = restore_from_saved(ctx.tensor_objects, saved_tensors)
903-
# Delete the references to tensor objects once they've been consumed
904-
# by the `restore_from_saved` method to construct back the actual tensors.
905-
ctx.tensor_objects = None
901+
tensors = restore_from_func_ctx(ctx)
906902

907903
if ctx.checkpoint: # do recomputation from the original args
908904

transformer_engine/pytorch/module/linear.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
QuantizedTensorStorage,
6262
Quantizer,
6363
prepare_for_saving,
64-
restore_from_saved,
64+
restore_from_func_ctx,
6565
)
6666
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
6767
from ..tensor.mxfp8_tensor import MXFP8Quantizer
@@ -501,15 +501,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
501501
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
502502

503503
with get_nvtx_range_context("_Linear_backward"):
504-
saved_tensors = ctx.saved_tensors
505504
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
506-
restore_from_saved(ctx.tensor_objects, saved_tensors)
505+
restore_from_func_ctx(ctx)
507506
)
508507

509-
# Delete the references to tensor objects once they've been consumed
510-
# by the `restore_from_saved` method to construct back the actual tensors.
511-
ctx.tensor_objects = None
512-
513508
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
514509
main_grad = (
515510
ctx.main_grad_func()

transformer_engine/pytorch/ops/fuser.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313

1414
from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
15-
from ..quantized_tensor import prepare_for_saving, restore_from_saved
15+
from ..quantized_tensor import prepare_for_saving, restore_from_func_ctx
1616
from .op import (
1717
BasicOperation,
1818
FusibleOperation,
@@ -212,8 +212,7 @@ def backward(
212212
basic_op_ctxs = func_ctx.basic_op_ctxs
213213

214214
# Restore saved tensors
215-
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
216-
func_ctx.tensor_objects = None
215+
saved_tensors = restore_from_func_ctx(func_ctx)
217216

218217
# Unflatten list of saved tensors
219218
for ctx in basic_op_ctxs:

transformer_engine/pytorch/quantized_tensor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ def restore_from_saved(
165165
list[Optional[torch.Tensor]],
166166
]
167167
):
168-
"""Recombine the tensor data and metadata during backward pass."""
168+
"""Recombine the tensor data and metadata during backward pass.
169+
Note: please use `restore_from_func_ctx` instead if you are restoring tensors from a function context to make sure tensor_objects is detached and its memory can be freed
170+
"""
169171
tensor_objects = []
170172
for tensor in tensors:
171173
if tensor is None or isinstance(tensor, torch.Tensor):
@@ -180,6 +182,24 @@ def restore_from_saved(
180182
return tensor_objects
181183

182184

185+
def restore_from_func_ctx(ctx: torch.autograd.function.FunctionCtx, return_saved_tensors=False) -> (
186+
list[Optional[torch.Tensor | QuantizedTensorStorage]]
187+
| tuple[
188+
list[Optional[torch.Tensor | QuantizedTensorStorage]],
189+
list[Optional[torch.Tensor]],
190+
]
191+
):
192+
"""Recombine the tensor data and metadata during backward pass and delete tensor objects attached to function context."""
193+
if not hasattr(ctx, "tensor_objects") or ctx.tensor_objects is None:
194+
raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
195+
out = restore_from_saved(
196+
ctx.tensor_objects, ctx.saved_tensors, return_saved_tensors=return_saved_tensors
197+
)
198+
# Delete the references to tensor objects once they've been consumed by the `restore_from_saved` method to construct back the actual tensors.
199+
ctx.tensor_objects = None
200+
return out
201+
202+
183203
class Quantizer(abc.ABC):
184204
"""Builder class for quantized tensors.
185205

0 commit comments

Comments
 (0)