Skip to content

Commit 4d0bbae

Browse files
committed
Revert "[PyTorch] Add op-level activation offload opt-out API (#3108)"
This reverts commit 9b06f26. Signed-off-by: Tim Moon <tmoon@nvidia.com>
1 parent 9b06f26 commit 4d0bbae

15 files changed

Lines changed: 61 additions & 160 deletions

tests/pytorch/test_fusible_ops.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -634,40 +634,6 @@ def test_pyt_autocast(
634634
assert x.grad.dtype == model_dtype
635635
assert op.weight.grad.dtype == model_dtype
636636

637-
def test_activation_offloading_policy(self, monkeypatch):
638-
"""Test opt-out API for activation CPU offloading."""
639-
import transformer_engine.pytorch.ops.op as op_module
640-
641-
calls = []
642-
tensor = torch.empty(1)
643-
tensor_id = id(tensor)
644-
op = te_ops.Identity()
645-
646-
monkeypatch.setattr(
647-
op_module,
648-
"mark_activation_offload",
649-
lambda *tensors: calls.append(("mark", [id(t) for t in tensors])),
650-
)
651-
monkeypatch.setattr(
652-
op_module,
653-
"mark_not_offload",
654-
lambda *tensors: calls.append(("skip", [id(t) for t in tensors])),
655-
)
656-
monkeypatch.setattr(op_module, "is_cpu_offload_enabled", lambda: True)
657-
658-
op.mark_for_cpu_offload_if_needed(tensor, None)
659-
assert calls == [("mark", [tensor_id])]
660-
661-
calls.clear()
662-
op.set_activation_offloading(False)
663-
op.mark_for_cpu_offload_if_needed(tensor)
664-
assert calls == [("skip", [tensor_id])]
665-
666-
calls.clear()
667-
op.set_activation_offloading(True)
668-
op.mark_for_cpu_offload_if_needed(tensor)
669-
assert calls == [("mark", [tensor_id])]
670-
671637

672638
class TestBasicOps:
673639
"""Tests for individual operations"""

transformer_engine/pytorch/ops/basic/activation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import transformer_engine_torch as tex
1515
from ...constants import DType
16+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1617
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
1718
from ...utils import clear_tensor_data
1819
from ..op import BasicOperation, OperationContext
@@ -113,7 +114,8 @@ def op_forward(
113114

114115
# Save state for backward pass
115116
if ctx.requires_grad:
116-
self.mark_for_cpu_offload_if_needed(x)
117+
if is_cpu_offload_enabled():
118+
mark_activation_offload(x)
117119
ctx.save_for_backward(x)
118120
ctx.dtype = dtype
119121
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -412,7 +414,8 @@ def fuser_forward(
412414

413415
ctx = basic_op_ctxs[0]
414416
if ctx.requires_grad:
415-
self.mark_for_cpu_offload_if_needed(x)
417+
if is_cpu_offload_enabled():
418+
mark_activation_offload(x)
416419
ctx.input_requires_grad = True
417420
ctx.extra_input_requires_grad = extra_input.requires_grad
418421
ctx.dtype = dtype

transformer_engine/pytorch/ops/basic/basic_linear.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch
1414

1515
from ...cpp_extensions import general_gemm
16+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1617
from ...distributed import (
1718
CudaRNGStatesTracker,
1819
gather_along_first_dim,
@@ -1048,13 +1049,11 @@ def op_forward(
10481049
else:
10491050
saved_input = x_local
10501051
saved_weight = w
1051-
1052-
# Activation CPU offloading
1053-
# Note: No special CPU offloading logic is needed for weights. saved_weight is
1054-
# either self.weight (nn.Parameter, auto-excluded from offload) or a
1055-
# workspace freshly created each forward pass.
1056-
self.mark_for_cpu_offload_if_needed(saved_input)
1057-
1052+
if is_cpu_offload_enabled():
1053+
# No special CPU offloading logic is needed for weights. saved_weight is
1054+
# either self.weight (nn.Parameter, auto-excluded from offload) or a
1055+
# workspace freshly created each forward pass.
1056+
mark_activation_offload(saved_input)
10581057
ctx.save_for_backward(saved_input, saved_weight)
10591058
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
10601059
ctx.backward_override = backward_override

transformer_engine/pytorch/ops/basic/dropout.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
import transformer_engine_torch as tex
12+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1213
from ...tensor import Quantizer
1314
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
1415
from .._common import maybe_autocast_dtype, maybe_dequantize
@@ -70,7 +71,8 @@ def op_forward(
7071

7172
# Save context for backward
7273
if ctx.requires_grad:
73-
self.mark_for_cpu_offload_if_needed(mask)
74+
if is_cpu_offload_enabled():
75+
mark_activation_offload(mask)
7476
ctx.save_for_backward(mask)
7577
ctx.impl = impl
7678
ctx.dropout_probability = self.dropout_probability

transformer_engine/pytorch/ops/basic/grouped_linear.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_2X_ACC_DGRAD,
2424
_2X_ACC_WGRAD,
2525
)
26-
from ...cpu_offload import is_cpu_offload_enabled, start_offload
26+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
2727
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
2828
from ...quantized_tensor import QuantizedTensorStorage
2929
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
@@ -1032,14 +1032,19 @@ def fuser_forward_save_ctx(
10321032
# Note: No special logic is needed for weights. They are
10331033
# either nn.Parameter (auto-excluded from offload) or are
10341034
# temporary workspaces freshly created in each forward pass.
1035-
saved = tensors_to_save[0]
1036-
offset = 4 if self._scale_bias else 3
1037-
if use_grouped_tensor_path:
1038-
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
1039-
self.mark_for_cpu_offload_if_needed(saved[offset])
1040-
else:
1041-
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
1042-
self.mark_for_cpu_offload_if_needed(saved[offset : offset + self.num_groups])
1035+
if is_cpu_offload_enabled():
1036+
saved = tensors_to_save[0]
1037+
offset = 4 if self._scale_bias else 3
1038+
if use_grouped_tensor_path:
1039+
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
1040+
grouped_x = saved[offset]
1041+
if grouped_x is not None:
1042+
mark_activation_offload(grouped_x)
1043+
else:
1044+
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
1045+
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
1046+
if live_xs:
1047+
mark_activation_offload(*live_xs)
10431048

10441049
ctx.save_for_backward(*tensors_to_save[0])
10451050

transformer_engine/pytorch/ops/basic/l2normalization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212

1313
from ...torch_version import torch_version
14+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1415
from ...jit import (
1516
l2normalization_fused,
1617
l2normalization_fwd_fused,
@@ -101,7 +102,8 @@ def op_forward(
101102

102103
# Save state for backward pass
103104
if requires_grad:
104-
self.mark_for_cpu_offload_if_needed(x, rsqrt_norm)
105+
if is_cpu_offload_enabled():
106+
mark_activation_offload(x, rsqrt_norm)
105107
ctx.save_for_backward(x, rsqrt_norm)
106108

107109
return y

transformer_engine/pytorch/ops/basic/layer_norm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
1616
from ...constants import TE_DType
17+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1718
from ...export import is_in_onnx_export_mode
1819
from ...tensor import Quantizer
1920
from ...utils import (
@@ -215,7 +216,8 @@ def op_forward(
215216

216217
# Save state for backward pass
217218
if ctx.requires_grad:
218-
self.mark_for_cpu_offload_if_needed(x, means, rstdevs)
219+
if is_cpu_offload_enabled():
220+
mark_activation_offload(x, means, rstdevs)
219221
ctx.save_for_backward(x, means, rstdevs)
220222
ctx.dtype = dtype
221223

transformer_engine/pytorch/ops/basic/rmsnorm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
1616
from ...constants import TE_DType
17+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1718
from ...export import is_in_onnx_export_mode
1819
from ...tensor import Quantizer
1920
from ...utils import (
@@ -196,7 +197,8 @@ def op_forward(
196197

197198
# Save state for backward pass
198199
if ctx.requires_grad:
199-
self.mark_for_cpu_offload_if_needed(x, rstdevs)
200+
if is_cpu_offload_enabled():
201+
mark_activation_offload(x, rstdevs)
200202
ctx.save_for_backward(x, rstdevs)
201203
ctx.dtype = dtype
202204

transformer_engine/pytorch/ops/basic/swiglu.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import transformer_engine_torch as tex
1414
from ...constants import DType
15+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1516
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
1617
from ...utils import clear_tensor_data
1718
from ..op import BasicOperation, OperationContext
@@ -126,7 +127,8 @@ def op_forward(
126127

127128
# Save state for backward pass
128129
if ctx.requires_grad:
129-
self.mark_for_cpu_offload_if_needed(input_)
130+
if is_cpu_offload_enabled():
131+
mark_activation_offload(input_)
130132
ctx.save_for_backward(input_)
131133
ctx.dtype = dtype
132134
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -309,7 +311,8 @@ def op_forward(
309311

310312
# Save state for backward pass
311313
if ctx.requires_grad:
312-
self.mark_for_cpu_offload_if_needed(x)
314+
if is_cpu_offload_enabled():
315+
mark_activation_offload(x)
313316
ctx.save_for_backward(x)
314317
ctx.dtype = dtype
315318
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -459,7 +462,8 @@ def fuser_forward(
459462
# Save state for backward pass
460463
ctx = basic_op_ctxs[0]
461464
if ctx.requires_grad:
462-
self.mark_for_cpu_offload_if_needed(input_)
465+
if is_cpu_offload_enabled():
466+
mark_activation_offload(input_)
463467
ctx.input_requires_grad = True
464468
ctx.extra_input_requires_grad = extra_input.requires_grad
465469
ctx.dtype = dtype

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414

1515
import transformer_engine_torch as tex
16-
from ...cpu_offload import is_cpu_offload_enabled, start_offload
16+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
1717
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
1818
from ...quantization import Recipe
1919
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
@@ -170,7 +170,7 @@ def fuser_forward(
170170
basic_op_kwargs: list[dict[str, Any]],
171171
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
172172
# Get basic operations
173-
fc1_op, activation_op, fc2_op = self.basic_ops
173+
fc1_op, _, fc2_op = self.basic_ops
174174
fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs
175175

176176
# Tensor properties
@@ -726,6 +726,7 @@ def fuser_forward(
726726
# Save state for backward pass
727727
if requires_grad:
728728
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
729+
activation_op = self.basic_ops[1]
729730
cpu_offloading = is_cpu_offload_enabled()
730731
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
731732
activation_recompute_in_mlp = bool(
@@ -753,9 +754,7 @@ def fuser_forward(
753754
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
754755
]
755756
start_offload(*activation_tensors)
756-
fc1_op.mark_for_cpu_offload_if_needed(grouped_fc1_x)
757-
activation_op.mark_for_cpu_offload_if_needed(activation_in)
758-
fc2_op.mark_for_cpu_offload_if_needed(saved_grouped_fc2_x)
757+
mark_activation_offload(*activation_tensors)
759758

760759
# FC1 saved-tensor layout.
761760
# [split_sizes, base_split_offsets, split_points,

0 commit comments

Comments
 (0)