diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7c75d11e3b..2f7e19792e 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -634,6 +634,40 @@ def test_pyt_autocast( assert x.grad.dtype == model_dtype assert op.weight.grad.dtype == model_dtype + def test_activation_offloading_policy(self, monkeypatch): + """Test opt-out API for activation CPU offloading.""" + import transformer_engine.pytorch.ops.op as op_module + + calls = [] + tensor = torch.empty(1) + tensor_id = id(tensor) + op = te_ops.Identity() + + monkeypatch.setattr( + op_module, + "mark_activation_offload", + lambda *tensors: calls.append(("mark", [id(t) for t in tensors])), + ) + monkeypatch.setattr( + op_module, + "mark_not_offload", + lambda *tensors: calls.append(("skip", [id(t) for t in tensors])), + ) + monkeypatch.setattr(op_module, "is_cpu_offload_enabled", lambda: True) + + op.mark_for_cpu_offload_if_needed(tensor, None) + assert calls == [("mark", [tensor_id])] + + calls.clear() + op.set_activation_offloading(False) + op.mark_for_cpu_offload_if_needed(tensor) + assert calls == [("skip", [tensor_id])] + + calls.clear() + op.set_activation_offloading(True) + op.mark_for_cpu_offload_if_needed(tensor) + assert calls == [("mark", [tensor_id])] + class TestBasicOps: """Tests for individual operations""" diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index f4beffe90c..8d458331e5 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -13,7 +13,6 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -114,8 +113,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x) + self.mark_for_cpu_offload_if_needed(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -414,8 +412,7 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x) + self.mark_for_cpu_offload_if_needed(x) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 6b17d66fcd..0b9296b7ef 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,7 +13,6 @@ import torch from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -1049,11 +1048,13 @@ def op_forward( else: saved_input = x_local saved_weight = w - if is_cpu_offload_enabled(): - # No special CPU offloading logic is needed for weights. saved_weight is - # either self.weight (nn.Parameter, auto-excluded from offload) or a - # workspace freshly created each forward pass. - mark_activation_offload(saved_input) + + # Activation CPU offloading + # Note: No special CPU offloading logic is needed for weights. saved_weight is + # either self.weight (nn.Parameter, auto-excluded from offload) or a + # workspace freshly created each forward pass. + self.mark_for_cpu_offload_if_needed(saved_input) + ctx.save_for_backward(saved_input, saved_weight) ctx.with_quantized_compute = with_quantized_compute and backward_override is None ctx.backward_override = backward_override diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 8850604aad..aa95a1d363 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,7 +9,6 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize @@ -71,8 +70,7 @@ def op_forward( # Save context for backward if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(mask) + self.mark_for_cpu_offload_if_needed(mask) ctx.save_for_backward(mask) ctx.impl = impl ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 73e328c9d1..7b5b4a263a 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -23,7 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload +from ...cpu_offload import is_cpu_offload_enabled, start_offload from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer @@ -1032,19 +1032,14 @@ def fuser_forward_save_ctx( # Note: No special logic is needed for weights. They are # either nn.Parameter (auto-excluded from offload) or are # temporary workspaces freshly created in each forward pass. - if is_cpu_offload_enabled(): - saved = tensors_to_save[0] - offset = 4 if self._scale_bias else 3 - if use_grouped_tensor_path: - # Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights] - grouped_x = saved[offset] - if grouped_x is not None: - mark_activation_offload(grouped_x) - else: - # Layout: [split_sizes, None, None, (scales?), *xs, *ws] - live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None] - if live_xs: - mark_activation_offload(*live_xs) + saved = tensors_to_save[0] + offset = 4 if self._scale_bias else 3 + if use_grouped_tensor_path: + # Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights] + self.mark_for_cpu_offload_if_needed(saved[offset]) + else: + # Layout: [split_sizes, None, None, (scales?), *xs, *ws] + self.mark_for_cpu_offload_if_needed(saved[offset : offset + self.num_groups]) ctx.save_for_backward(*tensors_to_save[0]) diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index be155c9356..fb9a81180a 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -11,7 +11,6 @@ import torch from ...torch_version import torch_version -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -102,8 +101,7 @@ def op_forward( # Save state for backward pass if requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x, rsqrt_norm) + self.mark_for_cpu_offload_if_needed(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3fda5145c6..2e00740967 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,7 +14,6 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -216,8 +215,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x, means, rstdevs) + self.mark_for_cpu_offload_if_needed(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 1d8d8be971..4df086a4e1 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,7 +14,6 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -197,8 +196,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x, rstdevs) + self.mark_for_cpu_offload_if_needed(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 02f330ede3..131786a9c9 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -12,7 +12,6 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -127,8 +126,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(input_) + self.mark_for_cpu_offload_if_needed(input_) ctx.save_for_backward(input_) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -311,8 +309,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x) + self.mark_for_cpu_offload_if_needed(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -462,8 +459,7 @@ def fuser_forward( # Save state for backward pass ctx = basic_op_ctxs[0] if ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(input_) + self.mark_for_cpu_offload_if_needed(input_) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index f03ccc15b5..1df4bf164e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,7 +13,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload +from ...cpu_offload import is_cpu_offload_enabled, start_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -170,7 +170,7 @@ def fuser_forward( basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, _, fc2_op = self.basic_ops + fc1_op, activation_op, fc2_op = self.basic_ops fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties @@ -726,7 +726,6 @@ def fuser_forward( # Save state for backward pass if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) - activation_op = self.basic_ops[1] cpu_offloading = is_cpu_offload_enabled() activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( @@ -754,7 +753,9 @@ def fuser_forward( t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None ] start_offload(*activation_tensors) - mark_activation_offload(*activation_tensors) + fc1_op.mark_for_cpu_offload_if_needed(grouped_fc1_x) + activation_op.mark_for_cpu_offload_if_needed(activation_in) + fc2_op.mark_for_cpu_offload_if_needed(saved_grouped_fc2_x) # FC1 saved-tensor layout. # [split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 8df929f799..b0911b8ae1 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,7 +10,6 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import BasicLinear, Bias @@ -129,8 +128,7 @@ def fuser_forward( else: saved_input = x_local saved_weight = w - if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + linear_op.mark_for_cpu_offload_if_needed(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5376a7d264..5aed52546b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,7 +10,6 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias @@ -126,8 +125,7 @@ def fuser_forward( else: saved_input = x_local saved_weight = w - if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + linear_op.mark_for_cpu_offload_if_needed(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index abeb39adfa..286dc58e8e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,7 +10,6 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale @@ -107,8 +106,7 @@ def fuser_forward( else: saved_input = x_local saved_weight = w - if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + linear_op.mark_for_cpu_offload_if_needed(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 3a8ff5438d..291fdb9773 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,7 +12,6 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size from ...quantization import FP8GlobalStateManager from ...module.base import ( @@ -354,8 +353,7 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - if is_cpu_offload_enabled(): - mark_activation_offload(x_local) + linear_op.mark_for_cpu_offload_if_needed(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 86bd60ed9c..21924ce08b 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,13 +14,22 @@ import torch from transformer_engine.common.recipe import Recipe +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ..quantization import ( FP8GlobalStateManager, QuantizerRole, RecipeState, autocast, ) -from ..tensor import Quantizer +from ..tensor import ( + GroupedTensorStorage, + QuantizedTensorStorage, + Quantizer, +) + + +# Tensor class supported by fusible operation +TensorLike = torch.Tensor | QuantizedTensorStorage | GroupedTensorStorage @dataclasses.dataclass @@ -190,6 +199,9 @@ def __init__(self) -> None: self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None self._quantizers: Optional[dict[str, list[Quantizer]]] = None + # Whether to participate when activation CPU offloading is enabled + self._activation_offloading_enabled: bool = True + @property def is_fused_op(self) -> bool: return False @@ -429,6 +441,73 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) + def set_activation_offloading(self, enabled: bool) -> None: + """Set whether to participate when activation CPU offloading is enabled. + + Offloading is controlled by an offloading context (see + ``get_cpu_offload_context``). Disabling this setting allows + this operation to opt out when the offloading context is + active, but enabling does not activate offloading outside that + context. + """ + self._activation_offloading_enabled = enabled + + def mark_for_cpu_offload_if_needed( + self, + *tensors: Iterable[Optional[TensorLike]] | TensorLike | None, + exclude: Iterable[Optional[TensorLike]] | TensorLike | None = None, + ) -> None: + """Mark tensors to include and exclude from activation CPU offloading. + + Call in op forward implementation in order to control what + tensors participate in CPU offloading. It does nothing if + offloading is not enabled (see ``get_cpu_offload_context``), + and it excludes all tensors if the op is not participating in + offloading. + """ + if not tensors and not exclude: + return + if not is_cpu_offload_enabled(): + return + + supported_classes = (torch.Tensor, QuantizedTensorStorage, GroupedTensorStorage) + + def filter_supported_and_extend( + out: list[TensorLike], + ts: Iterable[Optional[TensorLike]] | TensorLike | None, + ) -> None: + """Extend a list with objects that support CPU offloading. + + Filters out ``None`` and checks for unsupported classes. + """ + if ts is None: + return + if isinstance(ts, supported_classes): + ts = (ts,) + for t in ts: + if t is None: + continue + if not isinstance(t, supported_classes): + raise TypeError(f"{t.__class__.__name__} does not support CPU offloading.") + out.append(t) + + # Choose tensors to include and exclude from CPU offloading + include_tensors = [] + exclude_tensors = [] + is_enabled = self._activation_offloading_enabled + for t in tensors: + filter_supported_and_extend( + include_tensors if is_enabled else exclude_tensors, + t, + ) + filter_supported_and_extend(exclude_tensors, exclude) + + # Mark tensors + if include_tensors: + mark_activation_offload(*include_tensors) + if exclude_tensors: + mark_not_offload(*exclude_tensors) + @abc.abstractmethod def op_forward( self, @@ -748,6 +827,18 @@ def get_input_quantizer(self) -> Optional[Quantizer]: def get_grad_output_quantizer(self) -> Optional[Quantizer]: return self.basic_ops[-1].get_grad_output_quantizer() + def set_activation_offloading(self, enabled: bool) -> None: + """Whether to participate when activation CPU offloading is enabled globally. + + Offloading is controlled by an offloading context (see + ``get_cpu_offload_context``). Disabling this setting allows + this operation to opt out when the offloading context is + active, but enabling does not activate offloading outside that + context. + """ + for op in self.basic_ops: + op.set_activation_offloading(enabled) + def pre_first_fuser_forward(self) -> None: for op in self.basic_ops: op.pre_first_fuser_forward()