|
13 | 13 | import torch |
14 | 14 |
|
15 | 15 | import transformer_engine_torch as tex |
16 | | -from ...cpu_offload import is_cpu_offload_enabled, start_offload |
| 16 | +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload |
17 | 17 | from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor |
18 | 18 | from ...quantization import Recipe |
19 | 19 | from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer |
@@ -170,7 +170,7 @@ def fuser_forward( |
170 | 170 | basic_op_kwargs: list[dict[str, Any]], |
171 | 171 | ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: |
172 | 172 | # Get basic operations |
173 | | - fc1_op, activation_op, fc2_op = self.basic_ops |
| 173 | + fc1_op, _, fc2_op = self.basic_ops |
174 | 174 | fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs |
175 | 175 |
|
176 | 176 | # Tensor properties |
@@ -726,6 +726,7 @@ def fuser_forward( |
726 | 726 | # Save state for backward pass |
727 | 727 | if requires_grad: |
728 | 728 | mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) |
| 729 | + activation_op = self.basic_ops[1] |
729 | 730 | cpu_offloading = is_cpu_offload_enabled() |
730 | 731 | activation_is_srelu = isinstance(activation_op, ScaledSReLU) |
731 | 732 | activation_recompute_in_mlp = bool( |
@@ -753,9 +754,7 @@ def fuser_forward( |
753 | 754 | t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None |
754 | 755 | ] |
755 | 756 | start_offload(*activation_tensors) |
756 | | - fc1_op.mark_for_cpu_offload_if_needed(grouped_fc1_x) |
757 | | - activation_op.mark_for_cpu_offload_if_needed(activation_in) |
758 | | - fc2_op.mark_for_cpu_offload_if_needed(saved_grouped_fc2_x) |
| 757 | + mark_activation_offload(*activation_tensors) |
759 | 758 |
|
760 | 759 | # FC1 saved-tensor layout. |
761 | 760 | # [split_sizes, base_split_offsets, split_points, |
|
0 commit comments