Skip to content

Commit 9b06f26

Browse files
lhb8125, timmoon10, codex
authored
[PyTorch] Add op-level activation offload opt-out API (#3108)
* Add TE op CPU offload opt-out API

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Rename op API to activation offloading

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Move CPU offload gating to TE op call sites

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Preserve grouped linear offload start semantics

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Use setter for activation offload policy

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Limit activation offload policy helper to marking

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Move CPU offload imports to op module scope

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Patch activation offload test bound symbols

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Refactor base class offloading infrastructure

  Handle inclusion and exclusion in same function. Check whether CPU offloading is enabled internally. Tweak documentation and style.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Propagate activation offload policy helper

  Use BasicOperation.mark_for_cpu_offload_if_needed at op call sites and keep explicit offload synchronization checks where needed.

  Co-authored-by: OpenAI Codex <codex@openai.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move test into TestFuser suite

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug failure with grouped linear

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: hongbinl <hongbinl@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
1 parent 20e185c commit 9b06f26

15 files changed

Lines changed: 160 additions & 61 deletions

tests/pytorch/test_fusible_ops.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -634,6 +634,40 @@ def test_pyt_autocast(
634634
assert x.grad.dtype == model_dtype
635635
assert op.weight.grad.dtype == model_dtype
636636

637+
def test_activation_offloading_policy(self, monkeypatch):
638+
"""Test opt-out API for activation CPU offloading."""
639+
import transformer_engine.pytorch.ops.op as op_module
640+
641+
calls = []
642+
tensor = torch.empty(1)
643+
tensor_id = id(tensor)
644+
op = te_ops.Identity()
645+
646+
monkeypatch.setattr(
647+
op_module,
648+
"mark_activation_offload",
649+
lambda *tensors: calls.append(("mark", [id(t) for t in tensors])),
650+
)
651+
monkeypatch.setattr(
652+
op_module,
653+
"mark_not_offload",
654+
lambda *tensors: calls.append(("skip", [id(t) for t in tensors])),
655+
)
656+
monkeypatch.setattr(op_module, "is_cpu_offload_enabled", lambda: True)
657+
658+
op.mark_for_cpu_offload_if_needed(tensor, None)
659+
assert calls == [("mark", [tensor_id])]
660+
661+
calls.clear()
662+
op.set_activation_offloading(False)
663+
op.mark_for_cpu_offload_if_needed(tensor)
664+
assert calls == [("skip", [tensor_id])]
665+
666+
calls.clear()
667+
op.set_activation_offloading(True)
668+
op.mark_for_cpu_offload_if_needed(tensor)
669+
assert calls == [("mark", [tensor_id])]
670+
637671

638672
class TestBasicOps:
639673
"""Tests for individual operations"""

transformer_engine/pytorch/ops/basic/activation.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
import transformer_engine_torch as tex
1515
from ...constants import DType
16-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1716
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
1817
from ...utils import clear_tensor_data
1918
from ..op import BasicOperation, OperationContext
@@ -114,8 +113,7 @@ def op_forward(
114113

115114
# Save state for backward pass
116115
if ctx.requires_grad:
117-
if is_cpu_offload_enabled():
118-
mark_activation_offload(x)
116+
self.mark_for_cpu_offload_if_needed(x)
119117
ctx.save_for_backward(x)
120118
ctx.dtype = dtype
121119
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -414,8 +412,7 @@ def fuser_forward(
414412

415413
ctx = basic_op_ctxs[0]
416414
if ctx.requires_grad:
417-
if is_cpu_offload_enabled():
418-
mark_activation_offload(x)
415+
self.mark_for_cpu_offload_if_needed(x)
419416
ctx.input_requires_grad = True
420417
ctx.extra_input_requires_grad = extra_input.requires_grad
421418
ctx.dtype = dtype

transformer_engine/pytorch/ops/basic/basic_linear.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
import torch
1414

1515
from ...cpp_extensions import general_gemm
16-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1716
from ...distributed import (
1817
CudaRNGStatesTracker,
1918
gather_along_first_dim,
@@ -1049,11 +1048,13 @@ def op_forward(
10491048
else:
10501049
saved_input = x_local
10511050
saved_weight = w
1052-
if is_cpu_offload_enabled():
1053-
# No special CPU offloading logic is needed for weights. saved_weight is
1054-
# either self.weight (nn.Parameter, auto-excluded from offload) or a
1055-
# workspace freshly created each forward pass.
1056-
mark_activation_offload(saved_input)
1051+
1052+
# Activation CPU offloading
1053+
# Note: No special CPU offloading logic is needed for weights. saved_weight is
1054+
# either self.weight (nn.Parameter, auto-excluded from offload) or a
1055+
# workspace freshly created each forward pass.
1056+
self.mark_for_cpu_offload_if_needed(saved_input)
1057+
10571058
ctx.save_for_backward(saved_input, saved_weight)
10581059
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
10591060
ctx.backward_override = backward_override

transformer_engine/pytorch/ops/basic/dropout.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import torch
1111
import transformer_engine_torch as tex
12-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1312
from ...tensor import Quantizer
1413
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
1514
from .._common import maybe_autocast_dtype, maybe_dequantize
@@ -71,8 +70,7 @@ def op_forward(
7170

7271
# Save context for backward
7372
if ctx.requires_grad:
74-
if is_cpu_offload_enabled():
75-
mark_activation_offload(mask)
73+
self.mark_for_cpu_offload_if_needed(mask)
7674
ctx.save_for_backward(mask)
7775
ctx.impl = impl
7876
ctx.dropout_probability = self.dropout_probability

transformer_engine/pytorch/ops/basic/grouped_linear.py

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
_2X_ACC_DGRAD,
2424
_2X_ACC_WGRAD,
2525
)
26-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
26+
from ...cpu_offload import is_cpu_offload_enabled, start_offload
2727
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
2828
from ...quantized_tensor import QuantizedTensorStorage
2929
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
@@ -1032,19 +1032,14 @@ def fuser_forward_save_ctx(
10321032
# Note: No special logic is needed for weights. They are
10331033
# either nn.Parameter (auto-excluded from offload) or are
10341034
# temporary workspaces freshly created in each forward pass.
1035-
if is_cpu_offload_enabled():
1036-
saved = tensors_to_save[0]
1037-
offset = 4 if self._scale_bias else 3
1038-
if use_grouped_tensor_path:
1039-
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
1040-
grouped_x = saved[offset]
1041-
if grouped_x is not None:
1042-
mark_activation_offload(grouped_x)
1043-
else:
1044-
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
1045-
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
1046-
if live_xs:
1047-
mark_activation_offload(*live_xs)
1035+
saved = tensors_to_save[0]
1036+
offset = 4 if self._scale_bias else 3
1037+
if use_grouped_tensor_path:
1038+
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
1039+
self.mark_for_cpu_offload_if_needed(saved[offset])
1040+
else:
1041+
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
1042+
self.mark_for_cpu_offload_if_needed(saved[offset : offset + self.num_groups])
10481043

10491044
ctx.save_for_backward(*tensors_to_save[0])
10501045

transformer_engine/pytorch/ops/basic/l2normalization.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
import torch
1212

1313
from ...torch_version import torch_version
14-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1514
from ...jit import (
1615
l2normalization_fused,
1716
l2normalization_fwd_fused,
@@ -102,8 +101,7 @@ def op_forward(
102101

103102
# Save state for backward pass
104103
if requires_grad:
105-
if is_cpu_offload_enabled():
106-
mark_activation_offload(x, rsqrt_norm)
104+
self.mark_for_cpu_offload_if_needed(x, rsqrt_norm)
107105
ctx.save_for_backward(x, rsqrt_norm)
108106

109107
return y

transformer_engine/pytorch/ops/basic/layer_norm.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
1616
from ...constants import TE_DType
17-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1817
from ...export import is_in_onnx_export_mode
1918
from ...tensor import Quantizer
2019
from ...utils import (
@@ -216,8 +215,7 @@ def op_forward(
216215

217216
# Save state for backward pass
218217
if ctx.requires_grad:
219-
if is_cpu_offload_enabled():
220-
mark_activation_offload(x, means, rstdevs)
218+
self.mark_for_cpu_offload_if_needed(x, means, rstdevs)
221219
ctx.save_for_backward(x, means, rstdevs)
222220
ctx.dtype = dtype
223221

transformer_engine/pytorch/ops/basic/rmsnorm.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
1616
from ...constants import TE_DType
17-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1817
from ...export import is_in_onnx_export_mode
1918
from ...tensor import Quantizer
2019
from ...utils import (
@@ -197,8 +196,7 @@ def op_forward(
197196

198197
# Save state for backward pass
199198
if ctx.requires_grad:
200-
if is_cpu_offload_enabled():
201-
mark_activation_offload(x, rstdevs)
199+
self.mark_for_cpu_offload_if_needed(x, rstdevs)
202200
ctx.save_for_backward(x, rstdevs)
203201
ctx.dtype = dtype
204202

transformer_engine/pytorch/ops/basic/swiglu.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212

1313
import transformer_engine_torch as tex
1414
from ...constants import DType
15-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
1615
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
1716
from ...utils import clear_tensor_data
1817
from ..op import BasicOperation, OperationContext
@@ -127,8 +126,7 @@ def op_forward(
127126

128127
# Save state for backward pass
129128
if ctx.requires_grad:
130-
if is_cpu_offload_enabled():
131-
mark_activation_offload(input_)
129+
self.mark_for_cpu_offload_if_needed(input_)
132130
ctx.save_for_backward(input_)
133131
ctx.dtype = dtype
134132
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -311,8 +309,7 @@ def op_forward(
311309

312310
# Save state for backward pass
313311
if ctx.requires_grad:
314-
if is_cpu_offload_enabled():
315-
mark_activation_offload(x)
312+
self.mark_for_cpu_offload_if_needed(x)
316313
ctx.save_for_backward(x)
317314
ctx.dtype = dtype
318315
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -462,8 +459,7 @@ def fuser_forward(
462459
# Save state for backward pass
463460
ctx = basic_op_ctxs[0]
464461
if ctx.requires_grad:
465-
if is_cpu_offload_enabled():
466-
mark_activation_offload(input_)
462+
self.mark_for_cpu_offload_if_needed(input_)
467463
ctx.input_requires_grad = True
468464
ctx.extra_input_requires_grad = extra_input.requires_grad
469465
ctx.dtype = dtype

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
import torch
1414

1515
import transformer_engine_torch as tex
16-
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
16+
from ...cpu_offload import is_cpu_offload_enabled, start_offload
1717
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
1818
from ...quantization import Recipe
1919
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
@@ -170,7 +170,7 @@ def fuser_forward(
170170
basic_op_kwargs: list[dict[str, Any]],
171171
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
172172
# Get basic operations
173-
fc1_op, _, fc2_op = self.basic_ops
173+
fc1_op, activation_op, fc2_op = self.basic_ops
174174
fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs
175175

176176
# Tensor properties
@@ -726,7 +726,6 @@ def fuser_forward(
726726
# Save state for backward pass
727727
if requires_grad:
728728
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
729-
activation_op = self.basic_ops[1]
730729
cpu_offloading = is_cpu_offload_enabled()
731730
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
732731
activation_recompute_in_mlp = bool(
@@ -754,7 +753,9 @@ def fuser_forward(
754753
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
755754
]
756755
start_offload(*activation_tensors)
757-
mark_activation_offload(*activation_tensors)
756+
fc1_op.mark_for_cpu_offload_if_needed(grouped_fc1_x)
757+
activation_op.mark_for_cpu_offload_if_needed(activation_in)
758+
fc2_op.mark_for_cpu_offload_if_needed(saved_grouped_fc2_x)
758759

759760
# FC1 saved-tensor layout.
760761
# [split_sizes, base_split_offsets, split_points,

0 commit comments

Comments
 (0)