Skip to content

Commit 3fffa55

Browse files
lhb8125, timmoon10, claude, pre-commit-ci[bot]
authored
[PyTorch] Debug CPU offloading in grouped linear and grouped MLP (#3047)
* Support selective offload for fused grouped MLP

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Add no_offload_activation to grouped MLP ops

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Use offload_activation API for activation offload control

  Signed-off-by: hongbinl <hongbinl@nvidia.com>

* Fix CPU offloading correctness in ops layer

  - Revert per-module offload_activation API added in commits 376d28c and 933d64b; that belongs in a separate PR.
  - ops/basic/grouped_linear: add start_offload on input tensors before the GEMM, and mark_activation_offload / mark_not_offload in fuser_forward_save_ctx for both the split-quantize and grouped-tensor paths.
  - ops/fused/forward_grouped_mlp: remove no_offload_activation attribute lookups and the activation mark_not_offload calls that gated on them; add start_offload + mark_activation_offload for all saved activation tensors (grouped_fc1_x, activation_in, saved_grouped_fc2_x) and keep mark_not_offload only for weight tensors. Document why grouped_fc1_x is repacked into GroupedTensorStorage.
  - ops/basic/basic_linear: no change needed beyond the existing mark_activation_offload — unlike te.Linear there is no persistent weight cache, so the quantized weight workspace can be freely offloaded.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Construct internal grouped tensors within grouped linear and grouped MLP

  GroupedTensor should only be used when exposed externally. Otherwise GroupedTensorStorage has less CPU overhead. There also seems to be some issue with CPU offloading that has not yet been root-caused.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: hongbinl <hongbinl@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0dd1af2 commit 3fffa55

5 files changed

Lines changed: 108 additions & 20 deletions

File tree

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
import transformer_engine_torch as tex
1717

1818
from transformer_engine.common.recipe import Recipe
19-
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
19+
from transformer_engine.pytorch.tensor.grouped_tensor import (
20+
GroupedTensor,
21+
GroupedTensorStorage,
22+
)
2023
from .base import (
2124
get_dummy_wgrad,
2225
quantize_weight,
@@ -135,9 +138,9 @@ def _make_grouped_tensor(
135138
base_split_offsets: torch.Tensor,
136139
last_dim: int,
137140
dtype: torch.dtype,
138-
) -> GroupedTensor:
139-
"""Wrap a packed 2D buffer as a varying-first-dimension GroupedTensor."""
140-
return GroupedTensor(
141+
) -> GroupedTensorStorage:
142+
"""Wrap a packed 2D buffer as a varying-first-dimension GroupedTensorStorage."""
143+
return GroupedTensorStorage(
141144
shape=(data.size(0), last_dim),
142145
dtype=dtype,
143146
num_tensors=num_gemms,
@@ -154,13 +157,13 @@ def _make_grouped_bias(
154157
num_gemms: int,
155158
out_features: int,
156159
dtype: torch.dtype,
157-
) -> GroupedTensor:
160+
) -> GroupedTensorStorage:
158161
"""Pack per-GEMM biases into the grouped GEMM bias format."""
159162
bias_data = torch.stack(
160163
[_GroupedLinear._maybe_dequantize(bias, dtype) for bias in biases],
161164
dim=0,
162165
).contiguous()
163-
return GroupedTensor(
166+
return GroupedTensorStorage(
164167
shape=(num_gemms, out_features),
165168
dtype=dtype,
166169
num_tensors=num_gemms,

transformer_engine/pytorch/ops/basic/basic_linear.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,9 @@ def op_forward(
10501050
saved_input = x_local
10511051
saved_weight = w
10521052
if is_cpu_offload_enabled():
1053+
# No special CPU offloading logic is needed for weights. saved_weight is
1054+
# either self.weight (nn.Parameter, auto-excluded from offload) or a
1055+
# workspace freshly created each forward pass.
10531056
mark_activation_offload(saved_input)
10541057
ctx.save_for_backward(saved_input, saved_weight)
10551058
ctx.with_quantized_compute = with_quantized_compute and backward_override is None

transformer_engine/pytorch/ops/basic/grouped_linear.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_2X_ACC_DGRAD,
2424
_2X_ACC_WGRAD,
2525
)
26+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
2627
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
2728
from ...quantized_tensor import QuantizedTensorStorage
2829
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
@@ -783,7 +784,7 @@ def _get_grouped_weight_for_gemm(
783784
columnwise_usage: bool,
784785
with_quantized_compute: bool,
785786
dtype: torch.dtype,
786-
) -> GroupedTensor:
787+
) -> GroupedTensorStorage:
787788
"""Prepare weights for ``general_grouped_gemm_for_grouped_tensor``.
788789
Supports MXFP8/BF16/FP16 compute paths.
789790
"""
@@ -800,7 +801,7 @@ def _get_grouped_weight_for_gemm(
800801
weight_parts = weight_param.split_into_quantized_tensors()
801802
dequantized = [maybe_dequantize(w, dtype) for w in weight_parts]
802803
weight_data = torch.stack(dequantized, dim=0).contiguous()
803-
return GroupedTensor(
804+
return GroupedTensorStorage(
804805
shape=(num_groups * self.out_features, self.in_features),
805806
dtype=dtype,
806807
num_tensors=num_groups,
@@ -814,7 +815,7 @@ def _get_grouped_weight_for_gemm(
814815
if weight_param.rowwise_data.dtype == dtype:
815816
return weight_param
816817
weight_data = weight_param.rowwise_data.to(dtype=dtype)
817-
return GroupedTensor(
818+
return GroupedTensorStorage(
818819
shape=(num_groups * self.out_features, self.in_features),
819820
dtype=dtype,
820821
num_tensors=num_groups,
@@ -866,8 +867,8 @@ def _get_weight_tensors(self) -> list[torch.nn.Parameter]:
866867
def _get_grouped_bias_for_gemm(
867868
self,
868869
dtype: torch.dtype,
869-
) -> Optional[torch.Tensor]:
870-
"""Build a uniform GroupedTensor of per-group biases for the cublas
870+
) -> Optional[GroupedTensorStorage]:
871+
"""Build a uniform GroupedTensorStorage of per-group biases for the cublas
871872
grouped GEMM.
872873
873874
Each group expects a (1, out_features) bias vector. Returns ``None``
@@ -888,7 +889,7 @@ def _get_grouped_bias_for_gemm(
888889
]
889890
bias_data = torch.stack(bias_list, dim=0).contiguous()
890891

891-
return GroupedTensor(
892+
return GroupedTensorStorage(
892893
shape=(num_groups, self.out_features),
893894
dtype=dtype,
894895
num_tensors=num_groups,
@@ -1026,6 +1027,25 @@ def fuser_forward_save_ctx(
10261027
return
10271028

10281029
ctx = basic_op_ctxs[0]
1030+
1031+
# Activation CPU offloading
1032+
# Note: No special logic is needed for weights. They are
1033+
# either nn.Parameter (auto-excluded from offload) or are
1034+
# temporary workspaces freshly created in each forward pass.
1035+
if is_cpu_offload_enabled():
1036+
saved = tensors_to_save[0]
1037+
offset = 4 if self._scale_bias else 3
1038+
if use_grouped_tensor_path:
1039+
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
1040+
grouped_x = saved[offset]
1041+
if grouped_x is not None:
1042+
mark_activation_offload(grouped_x)
1043+
else:
1044+
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
1045+
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
1046+
if live_xs:
1047+
mark_activation_offload(*live_xs)
1048+
10291049
ctx.save_for_backward(*tensors_to_save[0])
10301050

10311051
num_groups = self.num_groups
@@ -1110,6 +1130,10 @@ def _fuser_forward_split_quantize(
11101130
xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
11111131
else:
11121132
xs = torch.split(x, split_sizes_int)
1133+
if is_cpu_offload_enabled():
1134+
live_xs = [t for t in xs if t is not None]
1135+
if live_xs:
1136+
start_offload(*live_xs)
11131137

11141138
# Allocate output tensor
11151139
in_shape = list(input_.size())
@@ -1205,7 +1229,7 @@ def _fuser_forward_grouped_tensor(
12051229
grouped_x = tex.group_quantize(x, input_quantizer, num_groups, split_sizes)
12061230
else:
12071231
# No quantize: wrap the contiguous high-precision buffer.
1208-
grouped_x = GroupedTensor(
1232+
grouped_x = GroupedTensorStorage(
12091233
shape=(total_tokens, self.in_features),
12101234
dtype=dtype,
12111235
num_tensors=num_groups,
@@ -1215,6 +1239,9 @@ def _fuser_forward_grouped_tensor(
12151239
tensor_offsets=base_split_offsets * self.in_features,
12161240
)
12171241

1242+
if is_cpu_offload_enabled() and grouped_x is not None:
1243+
start_offload(grouped_x)
1244+
12181245
# Build the weight GroupedTensor / list.
12191246
if self.single_grouped_weight:
12201247
# GroupedTensor
@@ -1238,7 +1265,7 @@ def _fuser_forward_grouped_tensor(
12381265
# Allocate output buffer and wrap as a GroupedTensorStorage view.
12391266
out_shape = original_shape[:-1] + [self.out_features]
12401267
out = torch.empty(out_shape, dtype=dtype, device=device)
1241-
grouped_out = GroupedTensor(
1268+
grouped_out = GroupedTensorStorage(
12421269
shape=(total_tokens, self.out_features),
12431270
dtype=dtype,
12441271
num_tensors=num_groups,
@@ -1566,7 +1593,7 @@ def _fuser_backward_grouped_tensor(
15661593
else:
15671594
dy_2d = maybe_dequantize(dy_2d, dtype)
15681595
# Wrap BF16/FP16 buffer as a GroupedTensorStorage for grouped gemm
1569-
grouped_dy = GroupedTensor(
1596+
grouped_dy = GroupedTensorStorage(
15701597
shape=(total_tokens, self.out_features),
15711598
dtype=dtype,
15721599
num_tensors=num_groups,
@@ -1602,7 +1629,7 @@ def _fuser_backward_grouped_tensor(
16021629
if ctx.input_requires_grad:
16031630
grad_input_shape = list(grad_output.size())[:-1] + [self.in_features]
16041631
grad_input = torch.empty(grad_input_shape, dtype=dtype, device=device)
1605-
grouped_grad_input = GroupedTensor(
1632+
grouped_grad_input = GroupedTensorStorage(
16061633
shape=(total_tokens, self.in_features),
16071634
dtype=dtype,
16081635
num_tensors=num_groups,

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch
1414

1515
import transformer_engine_torch as tex
16+
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
1617
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
1718
from ...quantization import Recipe
1819
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
@@ -23,6 +24,7 @@
2324
mark_grouped_tensor,
2425
)
2526
from ...tensor.grouped_tensor import GroupedTensor
27+
from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage
2628
from ...tensor.mxfp8_tensor import MXFP8Quantizer
2729
from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE
2830
from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU
@@ -316,14 +318,44 @@ def fuser_forward(
316318
# Group-quantize input tensor and convert dtypes if needed
317319
fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
318320
fc1_input_quantizer.optimize_for_gemm = True
321+
fc1_input_quantizer.internal = True
319322
input_quantizer = getattr(input_, "quantizer", None)
320323
if isinstance(input_, GroupedTensor) and (
321324
isinstance(fc1_input_quantizer, MXFP8Quantizer)
322325
and isinstance(input_quantizer, MXFP8Quantizer)
323326
or isinstance(fc1_input_quantizer, NVFP4Quantizer)
324327
and isinstance(input_quantizer, NVFP4Quantizer)
325328
):
326-
grouped_fc1_x = input_
329+
# GroupedTensor is a torch.Tensor subclass, so the CPU offload
330+
# infrastructure's prepare_for_saving treats it as a plain tensor
331+
# and does not decompose it into its component data tensors. By
332+
# repacking into a GroupedTensorStorage (not a torch.Tensor), we
333+
# ensure the fuser's prepare_for_saving call correctly decomposes
334+
# the activation before save_for_backward.
335+
grouped_fc1_x = GroupedTensorStorage(
336+
shape=input_.logical_shape,
337+
dtype=input_.fake_dtype,
338+
num_tensors=input_.num_tensors,
339+
shapes=input_.tensor_shapes,
340+
quantizer=input_.quantizer,
341+
data=input_.rowwise_data,
342+
columnwise_data=input_.columnwise_data,
343+
scale_inv=input_.scale_inv,
344+
columnwise_scale_inv=input_.columnwise_scale_inv,
345+
amax=input_.amax,
346+
columnwise_amax=input_.columnwise_amax,
347+
scale=input_.scale,
348+
first_dims=input_.first_dims,
349+
last_dims=input_.last_dims,
350+
tensor_offsets=input_.tensor_offsets,
351+
offsets=input_.offsets,
352+
scale_inv_offsets=input_.scale_inv_offsets,
353+
columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets,
354+
with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales,
355+
row_scaled_nvfp4=input_.row_scaled_nvfp4,
356+
nvfp4_use_4over6=input_.nvfp4_use_4over6,
357+
nvfp4_e4m3_max=input_.nvfp4_e4m3_max,
358+
)
327359
else:
328360
fc1_x = maybe_dequantize(input_, dtype)
329361
grouped_fc1_x = _group_quantize_for_grouped_mlp(
@@ -587,7 +619,7 @@ def fuser_forward(
587619
else:
588620
fc2_out_buf = fc2_out_buf + token_bias
589621
else:
590-
fc2_out_grouped = GroupedTensor(
622+
fc2_out_grouped = GroupedTensorStorage(
591623
shape=(in_shape[0], fc2_weight_shape[0]),
592624
dtype=dtype,
593625
num_tensors=num_groups,
@@ -616,7 +648,7 @@ def fuser_forward(
616648
fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"]
617649
fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3)
618650

619-
grouped_fc2_x = GroupedTensor(
651+
grouped_fc2_x = GroupedTensorStorage(
620652
shape=(in_shape[0], fc2_weight_shape[1]),
621653
dtype=dtype,
622654
num_tensors=num_groups,
@@ -695,6 +727,7 @@ def fuser_forward(
695727
if requires_grad:
696728
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
697729
activation_op = self.basic_ops[1]
730+
cpu_offloading = is_cpu_offload_enabled()
698731
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
699732
activation_recompute_in_mlp = bool(
700733
getattr(activation_op, "activation_recompute_in_mlp", False)
@@ -716,6 +749,13 @@ def fuser_forward(
716749
grouped_fc_x.rowwise_data = None
717750
grouped_fc_x.scale_inv = None
718751

752+
if cpu_offloading:
753+
activation_tensors = [
754+
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
755+
]
756+
start_offload(*activation_tensors)
757+
mark_activation_offload(*activation_tensors)
758+
719759
# FC1 saved-tensor layout.
720760
# [split_sizes, base_split_offsets, split_points,
721761
# grouped_fc1_x, *fc1_weight_tensors]
@@ -755,7 +795,7 @@ def fuser_forward(
755795
fc2_weight_tensors = (
756796
[grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight
757797
)
758-
fc2_saved: list[Optional[torch.Tensor]] = [
798+
fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [
759799
split_sizes,
760800
base_split_offsets,
761801
split_points,

transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,21 @@ def restore_from_saved(
387387
self.tensor_offsets = tensors[9]
388388
return tensors[10:]
389389

390+
def get_data_tensors(self):
391+
"""Get tensor fields that may be saved or offloaded.

Returns a 10-element tuple of this object's attributes, unfiltered and
in this fixed order: ``rowwise_data``, ``columnwise_data``,
``scale_inv``, ``columnwise_scale_inv``, ``amax``, ``columnwise_amax``,
``scale``, ``first_dims``, ``last_dims``, ``tensor_offsets``.

``tensor_offsets`` is at index 9, the same index ``restore_from_saved``
reads it back from.

NOTE(review): presumably the other indices mirror ``restore_from_saved``
as well -- only index 9 is visible from here; confirm before reordering.
NOTE(review): individual entries may be ``None`` when a field is not
populated -- callers should filter if they need live tensors; confirm.
"""
392+
return (
393+
self.rowwise_data,
394+
self.columnwise_data,
395+
self.scale_inv,
396+
self.columnwise_scale_inv,
397+
self.amax,
398+
self.columnwise_amax,
399+
self.scale,
400+
self.first_dims,
401+
self.last_dims,
402+
self.tensor_offsets,
403+
)
404+
390405
def clear(self) -> None:
391406
"""
392407
Reset tensor data and clear all buffers.

0 commit comments

Comments
 (0)