From 943f1e0ac9366491f8f589aefef62f65c3f3c07c Mon Sep 17 00:00:00 2001 From: Jennifer Zhou Date: Tue, 12 Nov 2024 20:30:40 -0800 Subject: [PATCH 001/239] Fix an int conversion error (#1325) fix an int conversion error Signed-off-by: Jennifer Zhou --- transformer_engine/jax/csrc/extensions/activation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 9d5fb4f7b4..a2090bceba 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -264,8 +264,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act auto *output = output_buf->untyped_data(); auto act_input_dims = act_input_buf.dimensions(); - auto m = product(act_input_dims, 0, act_input_dims.size() - 2); - auto n = act_input_dims.back(); + auto m = static_cast(product(act_input_dims, 0, act_input_dims.size() - 2)); + auto n = static_cast(act_input_dims.back()); auto act_len = act_input_dims.end()[-2]; auto input_shape = std::vector{m, n}; From c0a539c6f91395fe96eede92aea168cc43de1a15 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:31:50 -0800 Subject: [PATCH 002/239] [PyTorch] Fix ONNX export bug with operation-based API (#1320) Debug ONNX export with te.Sequential ONNX export assumes that all state dict objects are tensor, even extra state. 
Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 0bb6f25db8..d03a83d2ca 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -505,7 +505,7 @@ def forward( basic_op_kwargs=[kwargs], ) - def get_extra_state(self) -> Optional[torch.Tensor]: + def get_extra_state(self) -> torch.Tensor: """Serialize extra state Contains metadata for FP8 casting. @@ -534,7 +534,7 @@ def get_extra_state(self) -> Optional[torch.Tensor]: self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") ) if not has_fp8_state: - return None + return torch.Tensor() def to_cpu(src: torch.Tensor) -> torch.Tensor: """Helper function to make CPU copy of tensor @@ -588,7 +588,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: def set_extra_state(self, state: Optional[torch.Tensor]) -> None: """Load extra state""" - if state is None: + if state is None or state.numel() == 0: return # Deserialize state from byte tensor From 28aa41a3fdb68dc217aab41dc784933f1e3b5c23 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:12:22 -0800 Subject: [PATCH 003/239] [PyTorch] Remove special handling for FP8 params in FP8 recipe infrastructure (#1326) * Remove manual FP8 scale update for FP8 params Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Tim Moon Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fusible_ops.py | 2 +- transformer_engine/pytorch/fp8.py | 47 +++---------------- transformer_engine/pytorch/graph.py | 2 +- 
transformer_engine/pytorch/module/base.py | 4 +- transformer_engine/pytorch/ops/op.py | 4 +- .../pytorch/tensor/float8_tensor.py | 27 ----------- 6 files changed, 11 insertions(+), 75 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 29829ac4ac..ec539e1f06 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -293,7 +293,7 @@ def test_fp8_scale_update( ) # Check that scaling factors match expected - w_amax_ref = max(w_vals[: step + 2]) + w_amax_ref = max(w_vals[: step + 1]) x_amax_ref = max(x_vals[: step + 1]) dy_amax_ref = max(dy_vals[: step + 1]) w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index f95ba515cb..2a909dabc6 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -109,8 +109,6 @@ def reset(cls) -> None: cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} cls.skip_fp8_weight_update_tensor = None @classmethod @@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str: def get_key_in_buffer( cls, forward: bool, - fp8_weights: bool, fp8_recipe: DelayedScaling, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" + return f"{fwd_bwd_key}_{autocast_key}" @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: + def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: """Splits buffer key into relevant parts.""" - forward, fp8_weights, autocast_key = key.split("_", 2) + forward, autocast_key = key.split("_", 1) forward = forward == "forward" - fp8_weights = fp8_weights == "True" - return forward, fp8_weights, autocast_key + 
return forward, autocast_key @classmethod def add_fp8_tensors_to_global_buffer( cls, fp8_meta: Dict[str, Any], - fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: """ The amax reduction process happens completely outside the FP8 modules. @@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer( fp8_meta[index_in_buffer] = [] for forward in (True, False): - # This algorithm creates a two-way map with `autocast_to_fp8_params` and - # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights - # in an autocasted region and cross reference them in `float8_tensor.py` - # to perform the forward amax reduction. fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if fp8_meta_tensor_key not in fp8_meta: # Handles non-parameter FP8 modules, e.g. DPA. continue - if forward and fp8_weights is not None: - autocast_key = cls.get_unique_autocast_key( - fp8_meta["recipe"], fp8_meta["fp8_group"] - ) - fp8_weight_set = {id(w._data) for w in fp8_weights} - if autocast_key not in cls.autocast_to_fp8_params: - cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set - else: - cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[ - autocast_key - ].union(fp8_weight_set) - # Identify correct autocast key for a given param. 
- for w in fp8_weight_set: - cls.fp8_param_to_autocast[w] = autocast_key - - key = cls.get_key_in_buffer( - forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"] - ) + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] @@ -327,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty def reduce_and_update_fp8_tensors( cls, forward: bool = True, - fp8_weights: bool = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. - fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: continue - # Only skip a forward update when `fp8_weights` is explicitly set to `True` - # (inside optimizer) and the current key is not an `fp8_weight_update` key. - # For other cases, we need to reduce because of activation tensors. - # TODO(ksivaman) consider separate weight and activation fp8_tensors. - if fwd_update and fp8_weights and not fp8_weights_update: - continue if len(amax_buffer) == 0: continue @@ -434,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. 
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) + cls.reduce_and_update_fp8_tensors(forward=True) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index ed0ed1c008..c47b792a95 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -465,7 +465,7 @@ def new_fwd(*user_args, **user_kwargs): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, fp8_weights=m._get_fp8_params() + m.fp8_meta, ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 534174380f..3a15242c3a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -762,9 +762,7 @@ def prepare_forward( ) if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self._get_fp8_params() - ) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) # Activation recomputation is used and this is the first forward phase. 
if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index d03a83d2ca..04a66b7942 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -19,7 +19,7 @@ FP8GlobalStateManager, get_default_fp8_recipe, ) -from ._common import canonicalize_device, is_float8_tensor +from ._common import canonicalize_device @dataclasses.dataclass @@ -379,10 +379,8 @@ def pre_forward( self.get_fp8_meta("input"), ) if self.num_fp8_scales("param"): - fp8_params = list(filter(is_float8_tensor, self.parameters())) FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( self.get_fp8_meta("param"), - fp8_weights=(fp8_params if fp8_params else None), ) if self.num_fp8_scales("grad_output"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 36136292df..7ace68a222 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -74,30 +74,6 @@ def backward( return grad, None -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. 
- if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] - - class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @@ -676,9 +652,6 @@ def quantize_( ) dst._transpose_invalid = False - # Callback hook to perform amax reduction after optimizer step - post_optimizer_step_fwd_amax_reduction(self) - return self @classmethod From d1488e7339d9d7b33b9e681c337fa2eb0b84f2ad Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Nov 2024 14:23:12 -0800 Subject: [PATCH 004/239] [PyTorch] Fix multiple calls to saved_tensors in CP attention (#1334) * Limit to one call of ctx.saved_tensors per autograd bwd Signed-off-by: Kirthi Shankar Sivamani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 35 +++++++++++++------------ 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6b153fd3c1..28c1b45ffa 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2528,12 +2528,13 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] - cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] - rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] - attn_biases = 
ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + (*saved_tensors,) = ctx.saved_tensors + (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6] + (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8] + cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -3577,11 +3578,12 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] - cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] - out_per_step = ctx.saved_tensors[7:9] - softmax_lse_per_step = ctx.saved_tensors[9:11] - rng_states = ctx.saved_tensors[11:13] + (*saved_tensors,) = ctx.saved_tensors + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] + cu_seqlens_kv_per_step = saved_tensors[5:7] + out_per_step = saved_tensors[7:9] + softmax_lse_per_step = saved_tensors[9:11] + rng_states = saved_tensors[11:13] kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step @@ -4056,12 +4058,11 @@ def backward(ctx, dout): # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) - q, k, v, out = ctx.saved_tensors[:4] - cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ - 4:8 - ] - fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] - aux_ctx_tensors = ctx.saved_tensors[10:] + (*saved_tensors,) = ctx.saved_tensors + q, k, v, out = saved_tensors[:4] + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8] + fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10] + aux_ctx_tensors = 
saved_tensors[10:] qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type From 20b0473cd1a5e999d3f5996f5c45809410f45455 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:30:41 -0800 Subject: [PATCH 005/239] [PyTorch] Activation operations (#1164) * Add activation ops Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix lint warnings Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Update to use QuantizedTensor Signed-off-by: Tim Moon * Respect PyTorch autograd dtype Signed-off-by: Tim Moon * Rename CastFloat8 op to Quantize Signed-off-by: Tim Moon * Add support for fused dSwiGLU-cast-transpose Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 160 +++++++ .../pytorch/cpp_extensions/transpose.py | 39 ++ transformer_engine/pytorch/csrc/extensions.h | 6 + .../pytorch/csrc/extensions/pybind.cpp | 6 + .../pytorch/csrc/extensions/transpose.cpp | 69 ++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/activation.py | 390 ++++++++++++++++++ 7 files changed, 671 insertions(+) create mode 100644 transformer_engine/pytorch/ops/basic/activation.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ec539e1f06..fd2832c1d4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1362,6 +1362,166 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, 
x_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Activation functions""" + + # Tensor dimensions + in_shape = list(out_shape) + if activation in ("geglu", "reglu", "swiglu"): + in_shape[-1] *= 2 + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref: torch.Tensor + if activation == "gelu": + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) + elif activation == "geglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "reglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "swiglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operation + make_op = dict( + gelu=te_ops.GELU, + relu=te_ops.ReLU, + geglu=te_ops.GEGLU, + reglu=te_ops.ReGLU, + 
swiglu=te_ops.SwiGLU, + )[activation] + forward = te_ops.Sequential( + make_op(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_swiglu( + self, + *, + out_shape: Iterable[int] = (16, 16), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_output: bool, + fp8_grad_input: bool, + ): + + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + fp8 = fp8_output or fp8_grad_input + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # FP8 recipe + fp8_recipe = None + if fp8_grad_input: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + y_ref.backward(dy_ref) + + # Implementation with fusible operation + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=fp8_grad_input), + te_ops.SwiGLU(), + 
te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ddc3b67e9e..188c03b27c 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -16,6 +16,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_dswiglu_cast_transpose_fused", "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused( ) +def fp8_dswiglu_cast_transpose_fused( + grad_output: torch.Tensor, + inp: torch.Tensor, + *, + grad_input: torch.Tensor, + grad_input_transpose: torch.Tensor, + otype: tex.DType, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> None: + """Fused SwiGLU backward + FP8 cast + FP8 transpose""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + + # Launch kernel + return tex.fused_dswiglu_cast_transpose( + grad_output, + inp, + grad_input, + grad_input_transpose, + fp8_scales["scale"], + 
fp8_scales["amax"], + fp8_scales["scale_inv"], + otype, + **fp8_scales_offsets, + ) + + def fp8_multi_cast_transpose_fused( input_list: List[torch.Tensor], fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b039bf2d1b..3b49ece4a3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -210,6 +210,12 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset = 0, + int amax_offset = 0, int scale_inv_offset = 0); + void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 39679ed669..8856553c54 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, + "Fused SwiGLU backward + FP8 cast + FP8 transpose", + py::call_guard(), py::arg("grad_output"), py::arg("input"), + py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), + py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); 
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 56f6b56769..f373cdf83a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -196,6 +196,75 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset, + int amax_offset, int scale_inv_offset) { + using namespace transformer_engine; + + // Tensor dimensions + auto outer_dim = [](const at::Tensor& tensor) -> size_t { + return tensor.numel() / tensor.size(-1); + }; + const auto M = outer_dim(grad_output); + const auto N = static_cast(grad_output.size(-1)); + + // Check tensor dims + NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ", + grad_output.dim()); + NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim()); + NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, + ", but found ", outer_dim(input)); + NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, + ", but found ", input.size(-1)); + NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", + grad_input.dim()); + NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", + M, ", but found ", outer_dim(grad_input)); + NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad 
input tensor to have inner dimension of ", + 2 * N, ", but found ", grad_input.size(-1)); + NVTE_CHECK(grad_input_transpose.dim() == 2, + "Expected grad input transpose tensor to have 2 dims, but found ", + grad_input_transpose.dim()); + NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, + "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", + grad_input_transpose.size(0)); + NVTE_CHECK(grad_input_transpose.size(1) == M, + "Expected grad input tensor to have outer dimension of ", M, ", but found ", + grad_input_transpose.size(1)); + + // Check tensor format + NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); + NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); + NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); + NVTE_CHECK(grad_input_transpose.is_contiguous(), + "Expected grad input transpose tensor to be contiguous"); + NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), + "Expected grad output tensor and input tensor to have same dtype"); + NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, + "Expected grad input tensor to be uint8 buffer"); + NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, + "Expected grad input transpose tensor to be uint8 buffer"); + + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + auto dy_cu = makeTransformerEngineTensor(grad_output); + auto x_cu = makeTransformerEngineTensor(input); + auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); + + // 
Launch kernel + nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + void fused_multi_cast_transpose_base(std::vector input_list, std::vector scale_dptr_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 3dd8f64229..d6f4940c58 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,6 +4,7 @@ """Single tensor operations supported by the operation fuser.""" +from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py new file mode 100644 index 0000000000..a2e5a24a85 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -0,0 +1,390 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operations for activation functions.""" + +from __future__ import annotations +import abc +from typing import Optional + +import torch + +import transformer_engine_torch +from ...constants import TE_DType +from ...cpp_extensions import ( + geglu as tex_geglu, + gelu as tex_gelu, + reglu as tex_reglu, + relu as tex_relu, + swiglu as tex_swiglu, + fp8_dswiglu_cast_transpose_fused, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import clear_tensor_data, devices_match +from ..op import BasicOperation, OperationContext + + +class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): + r"""Apply activation function + + Activation functions are either element-wise unary functions or + variants of the gated linear unit (GLU). 
Recall that GLU is + computed by splitting the input tensor into chunks :math:`a` and + :math:`b` along the last dimension and computing + + .. math:: + \text{GLU}(a,b) = \sigma(a) * b + + .. warning:: + + Transformer Engine gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + @abc.abstractmethod + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + """Forward implementation + + Implementation from transformer_engine.pytorch.cpp_extensions. + + """ + + @abc.abstractmethod + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + """Backward implementation + + Implementation from transformer_engine_torch. + + """ + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = input_ + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if x.device.type != "cuda": + x = x.cuda() + if x.dtype != dtype: + x = x.to(dtype=dtype) + if not x.is_contiguous(): + x = x.contiguous() + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + with_fp8_output = False + output_fp8_meta = None + output_dtype = TE_DType[dtype] + output_fp8_scale_inv = None + if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: + with_fp8_output = True + fp8_meta = next_op.get_fp8_meta("input") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + output_fp8_meta = 
fp8_meta[fp8_meta_key] + output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + + # Launch kernel + y = self._activation_forward_impl( + x, + output_fp8_meta, + 0, + output_dtype, + scale_inv=output_fp8_scale_inv, + ) + + # Check output tensor + if y.dim() != x.dim(): + y = y.reshape(list(x.shape[:-1]) + [-1]) + if with_fp8_output: + y = Float8Tensor( + data=y, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=output_dtype, + fp8_scale_inv=output_fp8_scale_inv, + dtype=dtype, + ) + + # Save state for backward pass + ctx.save_for_backward(x) + ctx.fp8_enabled = fp8_enabled + ctx.prev_op = prev_op + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: + dy = dy.to(device=x.device, dtype=x.dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Launch kernel + dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) + + # Check grad input tensor + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Clear input tensor if possible + if ctx.prev_op is not None: + clear_tensor_data(x) + + return dx, () + + +class GELU(_ActivationOperation): + r"""Gaussian Error Linear Unit + + This computes the "tanh" approximation to GELU: + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + See `Gaussian Error Linear Units (GELUs)`__. 
+ + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_gelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgelu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified linear unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_relu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.drelu(*args, **kwargs) + + +class GEGLU(_ActivationOperation): + r"""Gaussian error gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_geglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgeglu(*args, **kwargs) + + +class ReGLU(_ActivationOperation): + r"""Rectified gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{ReGLU}(a,b) = \max(a,0) * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. 
Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_reglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dreglu(*args, **kwargs) + + +class SwiGLU(_ActivationOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{SwiGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + `GLU Variants Improve Transformer`__ + and `Gaussian Error Linear Units (GELUs)`__. 
+ + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_swiglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dswiglu(*args, **kwargs) + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Tensor attributes + dtype = x.dtype + device = x.device + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, device) or dy.dtype != dtype: + dy = dy.to(device=device, dtype=dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Check if FP8 is enabled + with_fp8_grad_input = False + grad_input_fp8_meta = None + grad_input_dtype = TE_DType[dtype] + grad_input_fp8_scale_inv = None + if ( + ctx.fp8_enabled + and ctx.prev_op is not None + and ctx.prev_op.num_fp8_scales("grad_output") > 0 + ): + with_fp8_grad_input = True + fp8_meta = ctx.prev_op.get_fp8_meta("grad_output") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + grad_input_fp8_meta = fp8_meta[fp8_meta_key] + grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) + + # Launch kernel + if with_fp8_grad_input: + # Fused with FP8 cast-transpose + input_dims = x.size() + flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]] + flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2] + dx = torch.empty(input_dims, dtype=torch.uint8, device=device) + dx_t = torch.empty( + (flat_input_dims[1], flat_input_dims[0]), + dtype=torch.uint8, + device=device, + ) + fp8_dswiglu_cast_transpose_fused( + dy.reshape(flat_output_dims), + x.reshape(flat_input_dims), + grad_input=dx.reshape(flat_input_dims), + grad_input_transpose=dx_t, + 
otype=grad_input_dtype, + fp8_meta=grad_input_fp8_meta, + fp8_meta_index=0, + scale_inv=grad_input_fp8_scale_inv, + ) + dx = Float8Tensor( + data=dx, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=grad_input_dtype, + fp8_scale_inv=grad_input_fp8_scale_inv, + dtype=dtype, + ) + dx._transpose = dx_t + dx._transpose_invalid = False + else: + # Standard impl + dx = self._activation_backward_impl(dy, x, TE_DType[dtype]) + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Note: This fails if op is preceded by an identity op like Quantize(forward=False) + # # Clear input tensor if possible + # if ctx.prev_op is not None: + # clear_tensor_data(x) + + return dx, () From 89e3292fcd482cb11b299c23a3933e6f6c3ae281 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 15 Nov 2024 15:06:53 -0800 Subject: [PATCH 006/239] Changed VERSION to 1.14.0.dev Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 28444e84a9..809a0327d8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.13.0.dev0 +1.14.0.dev0 From 994f19d05d2541854997e6c0ecb3f904af0e566c Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi <939877+kmaehashi@users.noreply.github.com> Date: Sat, 16 Nov 2024 08:59:54 +0900 Subject: [PATCH 007/239] Use `CMAKE_CURRENT_SOURCE_DIR` instead of `CMAKE_SOURCE_DIR` (#1333) use CMAKE_CURRENT_SOURCE_DIR instead of CMAKE_SOURCE_DIR Signed-off-by: Kenichi Maehashi --- transformer_engine/common/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3784689f9a..e32011367b 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -30,14 +30,14 @@ endif() # cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR - 
"${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") message(FATAL_ERROR - "Could not find cuDNN frontend API. " + "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. " "Try running 'git submodule update --init --recursive' " "within the Transformer Engine source.") endif() -include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) From b495120eff1dd5215a020fbc3390f46e4a242491 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 20 Nov 2024 10:37:54 -0800 Subject: [PATCH 008/239] [PyTorch] Fix GQA error message (#1328) * fix GQA error message Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 28c1b45ffa..8159f20e90 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7952,7 +7952,10 @@ def forward( assert ( key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads!" 
+ ) assert qkv_format in [ "sbhd", "bshd", From 6b9876879825bd2e6c8db7272f3d7f401f5563d3 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 20 Nov 2024 18:47:12 -0800 Subject: [PATCH 009/239] [PyTorch] Integration test for Megatron-LM (#1329) * Handle deprecated `hidden_size` arg in norm modules Signed-off-by: Tim Moon * Support initializing norm ops on CPU Signed-off-by: Tim Moon * Add integration test for Megatron-LM Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename Mcore integration test Signed-off-by: Tim Moon * Handle case in RMSNorm where hidden dim is not provided Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L1_pytorch_mcore_integration/test.sh | 58 +++++++++++++++++ .../pytorch/module/layernorm.py | 19 +++++- transformer_engine/pytorch/module/rmsnorm.py | 19 +++++- .../pytorch/ops/basic/layer_norm.py | 65 +++++++++++-------- .../pytorch/ops/basic/rmsnorm.py | 53 +++++++++------ transformer_engine/pytorch/ops/fuser.py | 6 +- 6 files changed, 168 insertions(+), 52 deletions(-) create mode 100644 qa/L1_pytorch_mcore_integration/test.sh diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh new file mode 100644 index 0000000000..01c9e14eb1 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Paths +: ${TE_PATH:=/opt/transformerengine} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} + +# Download Megatron-LM if needed +if [ ! 
-d "${MCORE_PATH}" ]; then + pushd $(dirname ${MCORE_PATH}) + git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + popd +fi + +# Megatron-LM invocation +COMMAND=" +NVTE_TORCH_COMPILE=0 +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +NVTE_FLASH_ATTN=1 +NVTE_FWD_LAYERNORM_SM_MARGIN=0 +NVTE_BWD_LAYERNORM_SM_MARGIN=0 +CUDA_DEVICE_MAX_CONNECTIONS=1 +NVTE_BIAS_GELU_NVFUSION=0 +NVTE_BIAS_DROPOUT_FUSION=0 + +python +-m torch.distributed.launch +--use_env +--nnodes=1 +--nproc_per_node=1 + +${MCORE_PATH}/pretrain_gpt.py +--tensor-model-parallel-size 1 +--pipeline-model-parallel-size 1 +--use-cpu-initialization +--num-layers 2 +--hidden-size 128 +--num-attention-heads 8 +--seq-length 128 +--max-position-embeddings 2048 +--micro-batch-size 1 +--global-batch-size 8 +--train-iters 10 +--eval-iters 10 +--lr 1e-4 +--mock-data +--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json +--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt +--transformer-impl transformer_engine +--fp8-format hybrid +" +COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') + +# Launch Megatron-LM +bash -c "${COMMAND}" diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 32142cf48c..b42079d299 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + 
"`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index f3651ecc19..bd7db1f775 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 99c9c493db..710f838581 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -20,7 +20,12 @@ ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from 
...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -84,28 +89,23 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed dtype = canonicalize_dtype(dtype) weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) bias = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -143,17 +143,18 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight bias = self.bias - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) - if bias.device.type != "cuda": - bias = torch.empty_like(bias, device=self.device) - else: - bias = bias.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) # Initialize values if self.zero_centered_gamma: @@ 
-184,17 +185,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor (shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) @@ -266,6 +271,7 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, means, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -282,9 +288,12 @@ def op_backward( # Saved tensors from forward pass x, means, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -312,6 +321,6 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) - grad_bias = reshape(db, self._shape) + grad_weight = reshape(dw, weight_dims) + grad_bias = reshape(db, weight_dims) return grad_input, 
(grad_weight, grad_bias) diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 4f0e2ddc22..84f05ce713 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -20,7 +20,12 @@ ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -83,22 +88,17 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=canonicalize_dtype(dtype), ) weight = torch.nn.Parameter(weight) @@ -133,12 +133,15 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values if self.zero_centered_gamma: 
@@ -165,17 +168,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor (shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) if isinstance(x, QuantizedTensor): @@ -241,6 +248,7 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -257,9 +265,12 @@ def op_backward( # Saved tensors from forward pass x, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -285,5 +296,5 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) + grad_weight = reshape(dw, weight_dims) return grad_input, (grad_weight,) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 
6fcb435e5c..8b2a04cff8 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -135,7 +135,11 @@ def forward( requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx].requires_grad = requires_grad - x.requires_grad_(requires_grad=requires_grad) + if requires_grad != x.requires_grad: + if requires_grad: + x.requires_grad_() + else: + x = x.detach() # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] From 8952bc41a2e32a47bba6d30f523e15678f1f1326 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 21 Nov 2024 18:15:07 -0800 Subject: [PATCH 010/239] [Core] Add function to convert container to string (#1342) * Add helper function to convert C++ container to string Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/cpp/util/test_string.cpp | 10 +++++++++ transformer_engine/common/util/string.h | 27 +++++++++++++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/cpp/util/test_string.cpp b/tests/cpp/util/test_string.cpp index 531994aff8..14c1cc11f3 100644 --- a/tests/cpp/util/test_string.cpp +++ b/tests/cpp/util/test_string.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include @@ -57,6 +58,12 @@ TEST(UtilTest, ToStringLike) { // to_string_like EXPECT_EQ(std::stof(to_string_like(-2.5f)), -2.5f); EXPECT_EQ(std::stod(to_string_like(2.25)), 2.25); EXPECT_EQ(std::stod(to_string_like(-4.5)), -4.5); + + // Container types + EXPECT_EQ(to_string_like(std::vector{-3,1,-4}), "(-3,1,-4)"); + EXPECT_EQ(to_string_like(std::vector{"Accept", "no", "substitutes", ".", + "Buy", "N", "V", "IDIA"}), + 
"(Accept,no,substitutes,.,Buy,N,V,IDIA)"); } TEST(UtilTest, ConcatStringsTest) { // concat_strings @@ -88,6 +95,9 @@ TEST(UtilTest, ConcatStringsTest) { // concat_strings EXPECT_EQ(std::stof(concat_strings(6.5f)), 6.5f); EXPECT_EQ(std::stod(concat_strings("-", 4.25)), -4.25); EXPECT_EQ(std::stod(concat_strings(8.5)), 8.5); + + // Container types + EXPECT_EQ(concat_strings("vector ", std::vector{1,-2,3}), "vector (1,-2,3)"); } TEST(UtilTest, RegexReplaceTest) { // regex_replace diff --git a/transformer_engine/common/util/string.h b/transformer_engine/common/util/string.h index c0a2aa1077..3b0db02809 100644 --- a/transformer_engine/common/util/string.h +++ b/transformer_engine/common/util/string.h @@ -13,15 +13,34 @@ namespace transformer_engine { -/*! \brief Convert to C-style or C++-style string */ +inline const std::string &to_string_like(const std::string &val) noexcept { return val; } + +constexpr const char *to_string_like(const char *val) noexcept { return val; } + +/* \brief Convert arithmetic type to string */ template ::value>::type> inline std::string to_string_like(const T &val) { return std::to_string(val); } -inline const std::string &to_string_like(const std::string &val) noexcept { return val; } - -constexpr const char *to_string_like(const char *val) noexcept { return val; } +/* \brief Convert container to string */ +template ::value>::type, + typename = decltype(std::declval().begin())> +inline std::string to_string_like(const T &container) { + std::string str; + str.reserve(1024); // Assume strings are <1 KB + str += "("; + bool first = true; + for (const auto &val : container) { + if (!first) { + str += ","; + } + str += to_string_like(val); + first = false; + } + str += ")"; + return str; +} /*! 
\brief Convert arguments to strings and concatenate */ template From ae393e81ea5816bc0e53af47cf49acc588394ba8 Mon Sep 17 00:00:00 2001 From: buptzyb Date: Mon, 25 Nov 2024 16:01:57 +0800 Subject: [PATCH 011/239] Support CUDA Graph for MoE models (#1233) * Align RNG tracker with megatron Signed-off-by: Robin Zhang Co-authored-by: Yifei Song * Fix module_params order and warmup bug in cudagraph Signed-off-by: Robin Zhang Co-authored-by: Yifei Song * Add fp8_group argument and fix fp8 accuracy issue for cudagraph Signed-off-by: Robin Zhang Co-authored-by: Yifei Song * Add TE modules and weights filters to support MoE models Signed-off-by: Robin Zhang Co-authored-by: Yifei Song * Revert self.fp8 Signed-off-by: Robin Zhang * Use hooks to filter module params Signed-off-by: Robin Zhang * Filter all TE modules in hooks Signed-off-by: Robin Zhang Co-authored-by: Yifei Song * Format code Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update graph.py Signed-off-by: Xin Yao * Revert CudaRNGStatesTracker Signed-off-by: Robin Zhang * Format Update Signed-off-by: Yifei Song * Revert "Use hooks to filter module params" This reverts commit 73a22e2e8bcf43ec84c23bc844b8d16d06626e26. 
Signed-off-by: Yifei Song * Remove filtering module params Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Signed-off-by: Xin Yao Signed-off-by: Yifei Song Co-authored-by: Yifei Song Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/fp8.py | 12 +-- transformer_engine/pytorch/graph.py | 86 +++++++++++++++++-- .../pytorch/module/layernorm_linear.py | 5 +- .../pytorch/module/layernorm_mlp.py | 5 +- transformer_engine/pytorch/module/linear.py | 6 +- 5 files changed, 97 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 2a909dabc6..15f20c81e5 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -442,16 +442,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] + fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) + fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) + fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) + 
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) + fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"]) @contextmanager diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index c47b792a95..6c33cc72b9 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -12,6 +12,7 @@ from torch._C import _graph_pool_handle from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.pytorch.constants import dist_group_type from .fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -173,11 +174,14 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): - per_callable_module_params.append( - tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - ) + for m_chunk in range(num_model_chunks): + for _ in range(num_microbatches): + for l_no in range(num_layers): + per_callable_module_params.append( + tuple(callables[m_chunk * num_layers + l_no].parameters()) + if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module) + else () + ) assert len(per_callable_module_params) == len(flatten_sample_args) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] @@ -201,13 +205,55 @@ def _make_graphed_callables( # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work # from ending up in any captures. torch.cuda.synchronize() - with torch.cuda.stream(torch.cuda.Stream()): + + # Get warmup func and func_idx. 
+ warmup_func_idx = [] + warmup_func = [] + if _order is None: for func_idx, func in enumerate(callables): + warmup_func_idx.append(func_idx) + warmup_func.append(func) + else: + fwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + m_chunk = c_id - 1 + for l_no in range(num_layers): + func = callables[m_chunk * num_layers + l_no] + func_idx = (m_chunk * num_microbatches * num_layers) + ( + fwd_idx[m_chunk] * num_layers + l_no + ) + warmup_func_idx.append(func_idx) + warmup_func.append(func) + fwd_idx[m_chunk] += 1 + assert len(warmup_func) == len( + sample_args + ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}." + assert len(warmup_func_idx) == len( + set(warmup_func_idx) + ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + + # Filter the TE modules that cudagraph can access. + visited_te_modules = set() + + def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument + if isinstance(module, TransformerEngineBaseModule): + visited_te_modules.add(module) + + # Run warmup and do the above filtering. + with torch.cuda.stream(torch.cuda.Stream()): + for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): + hooks = [] + for module in func.modules(): + hook = module.register_forward_hook(hook_fn) + hooks.append(hook) outputs, _ = _tree_flatten(func(*args, **kwargs)) + for hook in hooks: + hook.remove() grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -216,6 +262,11 @@ def _make_graphed_callables( allow_unused=allow_unused_input, ) del outputs, grad_inputs + # The following code is added specifically for MCore's special requirements, + # aimed at preventing warmup from altering the control flow. 
+ for module in func.modules(): + if hasattr(module, "is_first_microbatch"): + module.is_first_microbatch = True torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -462,6 +513,19 @@ def new_fwd(*user_args, **user_kwargs): isinstance(m, TransformerEngineBaseModule) and FP8GlobalStateManager.is_fp8_enabled() ): + if m not in visited_te_modules: + # Only Set the FP8 meta for the modules included by forward + continue + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + from transformer_engine.pytorch.attention import DotProductAttention + + if ( + isinstance(m, DotProductAttention) + and not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + ): + # Don't need to update FP8 meta for non-FP8 DPA + continue m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( @@ -538,6 +602,7 @@ def make_graphed_callables( fp8_enabled: bool = False, fp8_calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, + fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, @@ -579,6 +644,9 @@ def make_graphed_callables( using a higher precision. fp8_recipe: recipe.DelayedScaling, default = `None` recipe used for FP8 training. + fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the fp8 tensors + are reduced at the end of each training step. fp8_weight_caching: bool, default = `False` Whether or not to cache FP8 weights across microbatches. 
if set to `True`, the `is_first_microbatch` boolean argument must be passed into the forward @@ -607,7 +675,11 @@ def wrap_autocast(block): def forward_func(*args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True + enabled=fp8_enabled, + calibrating=fp8_calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group, + _graph=True, ): outputs = old_forward(*args, **kwargs) return outputs diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fbf1b97704..92b37fcb07 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1152,7 +1152,10 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 64e8c9ce36..1a651474bf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1484,7 +1484,10 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fed467210..9492725f56 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ 
-938,8 +938,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ - - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False From 60ce21f49a93b983f2499a6807e0087d12e4d7f4 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Mon, 25 Nov 2024 08:43:49 -0600 Subject: [PATCH 012/239] [Common] Moved framework agnostic THD kernels to common. (#1339) Moved framework agnostic THD kernels to common. --------- Signed-off-by: Michael Goldfarb --- .github/workflows/build.yml | 1 + transformer_engine/common/CMakeLists.txt | 1 + .../common/fused_attn/thd_utils.cu | 76 +++++ .../common/fused_attn/thd_utils.h | 249 ++++++++++++++++ .../pytorch/csrc/extensions/attention.cu | 268 +----------------- 5 files changed, 330 insertions(+), 265 deletions(-) create mode 100644 transformer_engine/common/fused_attn/thd_utils.cu create mode 100644 transformer_engine/common/fused_attn/thd_utils.h diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7039d38cf5..b5b262baff 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -70,6 +70,7 @@ jobs: run: pip install . 
-v env: NVTE_FRAMEWORK: jax + MAX_JOBS: 1 - name: 'Sanity check' run: python tests/jax/test_sanity_import.py paddle: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e32011367b..759c1c19ae 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -61,6 +61,7 @@ list(APPEND transformer_engine_SOURCES activation/swiglu.cu fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp + fused_attn/thd_utils.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu layer_norm/ln_api.cpp diff --git a/transformer_engine/common/fused_attn/thd_utils.cu b/transformer_engine/common/fused_attn/thd_utils.cu new file mode 100644 index 0000000000..a1e353be71 --- /dev/null +++ b/transformer_engine/common/fused_attn/thd_utils.cu @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../cudnn_utils.h" +#include "thd_utils.h" + +namespace transformer_engine { +namespace fused_attn { + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. 
+ assert(seqlen % (world_size * 2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); + + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); + + for (int idx = 
laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/common/fused_attn/thd_utils.h new file mode 100644 index 0000000000..c9a62727e6 --- /dev/null +++ b/transformer_engine/common/fused_attn/thd_utils.h @@ -0,0 +1,249 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ + +#include +#include + +namespace transformer_engine { +namespace fused_attn { + +/*************************************************************************************************** + * Support THD format for Context Parallel: Binary search an array for a target value + **************************************************************************************************/ + +__forceinline__ __device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank); + +/*************************************************************************************************** + * Support 
THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token); + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +struct LseCorrectionFunctor { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +struct ReadLseFunctor { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, + int num_heads, int total_tokens) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * total_tokens / 2 
+ token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + idx = row * total_tokens + col + seq_len; + half_idx = row * total_tokens / 2 + col; + } + + Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, + float *lse_per_step, int *cu_seqlens, int batch, + int num_heads, int dim_per_head, int lse_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + 
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; + } + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +struct EmptyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + +#pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +template +__global__ void thd_grad_correction_kernel(dtype *grad, 
dtype *grad_per_step, int *cu_seqlens, + int batch, int hidden_size, int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 8088a2b8f1..d03a10ced3 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -4,8 +4,11 @@ * See LICENSE for license information. ************************************************************************/ +#include "common/fused_attn/thd_utils.h" #include "extensions.h" +using namespace transformer_engine::fused_attn; + constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -1359,64 +1362,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -/*************************************************************************************************** - * Support THD format for Context Parallel: Binary search - **************************************************************************************************/ - -__forceinline__ __device__ int binary_search(int target, int *array, int len) { - int left = 1, right = len - 1; - while (left < right) { - int mid = (left + right) / 2; - if (array[mid] <= target) { - left = mid + 1; - } else { - right = mid; - } - } - return left - 1; -} - /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ -__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, - int hidden_size_in_bytes, int half_idx, - int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; - int laneid = threadIdx.x % 32; - int num_warps = (blockDim.x * gridDim.x) / 32; - int num_total_tokens = cu_seqlens_s[batch]; - int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * 
hidden_size_in_bytes; - half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); - tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); - - for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { - int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); - - size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; - float4 *cur_half_token = - reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); - - offset_in_bytes = - (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; - float4 *cur_token = - reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); - - for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { - cur_half_token[idx] = cur_token[idx]; - } - } -} - at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, int half_idx) { NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); @@ -1464,51 +1413,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s * Support THD format for Context Parallel: softmax_lse related operations **************************************************************************************************/ -template -__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int total_tokens) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - int num_total_tokens = cu_seqlens_s[batch]; - - for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, half_idx; - if constexpr (lse_packed) { - idx = head_id * total_tokens + 
token_id + cu_seqlens_s[seq_id + 1]; - half_idx = head_id * total_tokens / 2 + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - - idx = row * total_tokens + col + seq_len; - half_idx = row * total_tokens / 2 + col; - } - - Functor::run(lse, half_lse, idx, half_idx); - } - } -} - -struct LseCorrectionFunctor { - __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, - size_t half_idx) { - double val = lse[idx]; - float val_per_step = half_lse[half_idx]; - double max_scale = max(val, val_per_step); - double min_scale = min(val, val_per_step); - lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); - } -}; - void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); @@ -1559,13 +1463,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st } } -struct ReadLseFunctor { - __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, - size_t half_idx) { - half_lse[half_idx] = lse[idx]; - } -}; - at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); @@ -1620,59 +1517,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ * Support THD format for Context Parallel: Out correction in forward **************************************************************************************************/ -template -__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, - float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int lse_seqlen) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += 
blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); - } - __syncthreads(); - - int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; - int lane_id = threadIdx.x % tile_size; - int num_tiles = (blockDim.x * gridDim.x) / tile_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); - - for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, idx_per_step; - - if constexpr (lse_packed) { - idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * lse_seqlen + col + seq_len * only_second_half; - idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; - } - float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); - - idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx = (idx * num_heads + head_id) * dim_per_head; - idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; - dtype *cur_out = out + idx; - dtype *cur_out_per_step = out_per_step + idx_per_step; - - for (int j = lane_id; j < num_loops_per_head; j += tile_size) { - float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; - float4 data = reinterpret_cast(cur_out)[j]; - dtype *p_per_step = reinterpret_cast(&data_per_step); - dtype *p = reinterpret_cast(&data); - for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { - p[k] += (p_per_step[k] == 0 ? 
0 : p_per_step[k] * lse_corrected_exp); - } - reinterpret_cast(cur_out)[j] = data; - } - } - } -} - template static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, @@ -1773,87 +1617,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at * Support THD format for Context Parallel: Gradients correction in backward **************************************************************************************************/ -template -__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, - int batch, int hidden_size, int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - if constexpr (functor_idx < 2) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } else { - cu_seqlens_s[i] = cu_seqlens[i]; - } - } - __syncthreads(); - - int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; - int lane_id = threadIdx.x % group_size; - int num_groups = (blockDim.x * gridDim.x) / group_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size; - if constexpr (functor_idx < 2) { - grad_per_step = grad_per_step + offset / 2 * blockIdx.y; - } else { - grad_per_step = grad_per_step + offset * blockIdx.y; - } - grad = grad + offset * blockIdx.y; - - for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - - int token_offset; - bool is_first_half; - if constexpr (functor_idx < 2) { - token_offset = cu_seqlens_s[seq_id + functor_idx]; - is_first_half = (functor_idx == 0); - } else { - token_offset = 0; - int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); - } - - dtype *token = 
&grad[(token_id + token_offset) * static_cast(hidden_size)]; - dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; - for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { - if (is_first_half) { - Functor_0::run(token, token_per_step, idx); - } else { - Functor_1::run(token, token_per_step, idx); - } - } - } -} - -struct EmptyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} -}; - -struct CopyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { - reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; - } -}; - -template -struct AddFunctor { - __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { - float4 d_ = reinterpret_cast(token)[idx]; - dtype *p_ = reinterpret_cast(&d_); - - float4 d = reinterpret_cast(token_per_step)[idx]; - dtype *p = reinterpret_cast(&d); - -#pragma unroll - for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { - p_[i] += p[i]; - } - - reinterpret_cast(token)[idx] = d_; - } -}; - template static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens) { @@ -1945,31 +1708,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, * Support THD format for Context Parallel: Generate partitioned indices for input tokens **************************************************************************************************/ -__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, - int total_tokens, int world_size, int rank) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - int seqlen = cu_seqlens[i]; - // Currently we assume that each sequence length is divisible by (world_size*2) since we have - // to distribute each sequence evenly to different GPUs. 
- assert(seqlen % (world_size * 2) == 0); - cu_seqlens_s[i] = seqlen / world_size; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - - for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; - index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; - output[token_id] = index; - } -} - at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank) { NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); From a132ac499d3b388b6fe658dfda03829a06edc3c3 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Wed, 27 Nov 2024 09:35:33 -0800 Subject: [PATCH 013/239] Fix cuda graph capture for grouped gemm (#1345) * retain_graph=True for grouped gemm Signed-off-by: Xiaowei Ren * remove an unnecessary retain_graph=True Signed-off-by: Xiaowei Ren * make retain_graph in graph capture configurable Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/graph.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 6c33cc72b9..f44500f7f2 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -64,6 +64,7 @@ def _make_graphed_callables( sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, + retain_graph_in_backward: bool = False, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -320,6 +321,7 @@ def hook_fn(module, 
inputs, outputs): # pylint: disable=unused-argument grad_outputs=tuple(o for o in static_grad_outputs if o is not None), only_inputs=True, allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs @@ -371,6 +373,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument grad_outputs=tuple(o for o in static_grad_outputs if o is not None), only_inputs=True, allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that @@ -606,6 +609,7 @@ def make_graphed_callables( fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, + retain_graph_in_backward: bool = False, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -632,6 +636,8 @@ def make_graphed_callables( pool: (tuple of) int, default = `None`, optional An instance returned from function `torch.cuda.graph_pool_handle` that hints this graph may share memory with the indicated pool. + retain_graph_in_backward: bool, default = `False` + Whether to set retain_graph=True in backward graph capture. FP8-related parameters ---------------------- @@ -716,6 +722,7 @@ def forward_func(*args, **kwargs): sample_kwargs=sample_kwargs, _order=_order, pool=pool, + retain_graph_in_backward=retain_graph_in_backward, ) # Ensures warmup does not affect numerics for ops such as dropout. 
From 09519718887e78d62156bf55589590b089e76797 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:07:37 -0800 Subject: [PATCH 014/239] Update list of CI users (#1340) * Update list of CI users Signed-off-by: Tim Moon * Update list of CI users Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c2317c6509..586abd0541 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -40,6 +40,8 @@ jobs: || github.actor == 'vasunvidia' || github.actor == 'erhoo82' || github.actor == 'kocchop' + || github.actor == 'youngeunkwon0405' + || github.actor == 'KshitijLakhani' ) steps: - name: Check if comment is issued by authorized person From 64126aa8c469b2a97ace01f925f3d5786d5fd1bb Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Mon, 2 Dec 2024 12:26:48 -0800 Subject: [PATCH 015/239] Improving communication overlap for the case of multi kernel queue usage (#1308) * draft implementation Signed-off-by: Youngeun Kwon * compile error fix Signed-off-by: Youngeun Kwon * fix compile error Signed-off-by: Youngeun Kwon * remove print Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Edit comments Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * edit the bulk-overlap test case Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add version guard Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add runtime version guard Signed-off-by: Youngeun Kwon * fix the version guard Signed-off-by: Youngeun Kwon 
--------- Signed-off-by: Youngeun Kwon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../distributed/test_comm_gemm_overlap.py | 34 +++-- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 30 ++++- .../userbuffers/userbuffers.cu | 116 ++++++++++++++---- .../userbuffers/userbuffers.h | 18 ++- .../transformer_engine/comm_gemm_overlap.h | 2 +- 5 files changed, 157 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index ce46a72189..f81fbae1fe 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): @pytest.mark.parametrize( - "comm_type,fp8", + "comm_type, fp8, connections", [ - ("AG", False), - ("RS", False), - ("RS", True), + ("AG", False, 1), + ("RS", False, 1), + ("RS", True, 1), + ("AG", False, 8), + ("RS", False, 8), + ("RS", True, 8), + ], + ids=[ + "ALL-GATHER - BF16 - 1 connections", + "REDUCE-SCATTER - BF16 - 1 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], - ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "], ) -def test_bulk_overlaps(comm_type, fp8): +def test_bulk_overlaps(comm_type, fp8, connections): """ Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + if connections == 8: + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip( + "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" + " 9.0 (HOPPER ARCH)." 
+ ) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + else: + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) @pytest.mark.parametrize( diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index a663385b68..c6f0f870ff 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -90,6 +90,23 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_comm, 0); + + /* + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. + This is needed only for Hopper, which uses persistent CTA execution. 
+ */ + int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); + int runtime_version = 0; + cudaRuntimeGetVersion(&runtime_version); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { + cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + } else { + _comm_launch_event = 0; + } } CommOverlapCore::~CommOverlapCore() { @@ -97,6 +114,7 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); + if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); if (_atomic_gemm) cudaFree(_counter.dptr()); @@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper // Communication: AG and RS int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized); @@ -178,13 +197,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, - comm_elements, _ub_comm, _stream_comm); + comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } } assert(pre_gelu_out.numel() == 0); + // When the kernel launch order is defined, enforce the GEMM 
kernel launch to wait for the communication kernel launch + if (_comm_launch_event) + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 26843d8107..91667958e7 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS) cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; +#if (CUDART_VERSION >= 12030) +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \ + attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3 +#else +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 +#endif + +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ + ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? 
comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; + #define callranks_ag(x) \ if (ar_nvsize == x) { \ int arg1 = op - NVTE_MAX_OPS, \ @@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } } } @@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con } void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, 
cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } } } void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, communicator *comm, - cudaStream_t stream) { + cudaStream_t stream, cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm_launch_event) 
{ + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + } } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + } } } void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream, + comm_launch_event); } template void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) 
callranks_rs_oop_fp8(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } } template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream) { + const int elements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { reducescatter2_userbuff_stridedoutput_fp8(output, scale, handler, offset, elements, 1, 0, - comm, stream); + comm, stream, comm_launch_event); } template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git 
a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 57e68afce0..75655ef691 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * // for TP-parallelism, only single node is implemented void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); /* each Rank input is allgather2_userbuff_inplace: offset+myrank*elements @@ -228,21 +229,26 @@ for(int slice=0;slice void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream = 0); + const int elements, communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 17ecca5ff0..1d5d192a39 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -62,7 +62,7 @@ class CommOverlapCore { bool 
_ubuf_scale_inv_initialized{false}; std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, From 44f6ff2e65182681f2ebaba55980fe5590b2b5e9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:01:34 -0800 Subject: [PATCH 016/239] add paged attention; test_kv_cache_accuray and test_paged_attn pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 7 +- tests/pytorch/fused_attn/test_paged_attn.py | 403 ++++++++++ tests/pytorch/test_numerics.py | 56 +- transformer_engine/common/CMakeLists.txt | 3 +- .../common/fused_attn/fused_attn.cpp | 61 +- .../fused_attn_f16_arbitrary_seqlen.cu | 273 ++++--- .../fused_attn_f16_arbitrary_seqlen.h | 4 +- .../common/fused_attn/fused_attn_fp8.cu | 2 + transformer_engine/common/fused_attn/utils.cu | 46 ++ transformer_engine/common/fused_attn/utils.h | 10 +- .../include/transformer_engine/fused_attn.h | 52 +- .../common/util/pybind_helper.h | 8 +- transformer_engine/pytorch/attention.py | 756 ++++++++++++++---- transformer_engine/pytorch/constants.py | 6 + .../pytorch/cpp_extensions/fused_attn.py | 14 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/attention.cu | 20 +- .../pytorch/kv_cache_manager_non_paged.py | 136 ++++ .../pytorch/kv_cache_manager_paged.py | 243 ++++++ 19 files changed, 1758 insertions(+), 344 deletions(-) create mode 100644 tests/pytorch/fused_attn/test_paged_attn.py create mode 100644 transformer_engine/pytorch/kv_cache_manager_non_paged.py create mode 100644 transformer_engine/pytorch/kv_cache_manager_paged.py diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 
dfe3283513..fa5b74cb45 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -25,6 +25,7 @@ check_set_window_size, AttentionParams, _attention_backends, + InferenceParams, ) from transformer_engine.pytorch.constants import TE_DType import transformer_engine.pytorch.cpp_extensions as ext @@ -89,6 +90,7 @@ def __init__( num_layers: int = 1, bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), + total_requests: int = 1, ): self.batch_size = batch_size self.num_heads = num_heads @@ -107,6 +109,7 @@ def __init__( self.num_layers = num_layers self.bias_shape = bias_shape self.window_size = window_size + self.total_requests = total_requests @contextmanager @@ -129,6 +132,7 @@ def _get_attention_backends( deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + inference_params: Optional[InferenceParams] = None, ) -> Tuple[List, List]: """Check if what attention backends support a model configuration""" @@ -183,6 +187,7 @@ def test(): deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, + inference_params=inference_params, ) _, _, fused_attention_backend, _, available_backends = get_attention_backend( attention_params @@ -1299,7 +1304,7 @@ def _run_transformer_layer( model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_9": ModelConfig(2, 12, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py new file mode 100644 index 0000000000..a2fa6e81a5 --- /dev/null +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -0,0 +1,403 @@ +from 
collections import OrderedDict +import os +import logging + +import pytest +import torch + +from torch.distributions import Exponential +from transformer_engine.pytorch.attention import ( + DotProductAttention, + InferenceParams, +) +from transformer_engine.pytorch.utils import is_bf16_compatible +from test_fused_attn import ( + ModelConfig, + reset_rng_states, + _get_attention_backends, +) +from tests.pytorch.test_numerics import assert_allclose + +# Initialize RNG state +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +_cpu_rng_state = torch.get_rng_state() +_cuda_rng_state = torch.cuda.get_rng_state() + +class Batch(object): + def __init__(self): + self.batch_size = 0 + self.seq_ids = torch.Tensor([]).to(dtype=torch.bool,device='cpu') + self.ctx_lens = torch.Tensor([]).to(dtype=torch.bool,device='cpu') + self.gen_lens = torch.Tensor([]).to(dtype=torch.bool,device='cpu') + self.total_lens = self.ctx_lens + self.gen_lens + self.expected_gen_lens = torch.Tensor([]).to(dtype=torch.bool,device='cpu') + self.finished = torch.Tensor([]).to(dtype=torch.bool,device='cpu') + self.step_lens_q = torch.Tensor([]).to(dtype=torch.int32,device='cpu') + + def copy(self): + new_batch = Batch() + new_batch.batch_size = self.batch_size + new_batch.seq_ids = self.seq_ids + new_batch.ctx_lens = self.ctx_lens + new_batch.gen_lens = self.gen_lens + new_batch.total_lens = self.total_lens + new_batch.expected_gen_lens = self.expected_gen_lens + new_batch.finished = self.finished + new_batch.step_lens_q = self.step_lens_q + return new_batch + + def print(self, logger, header='current batch:'): + logger.debug(header) + logger.debug(' {:<17s}: {}'.format('batch_size',self.batch_size)) + logger.debug(' {:<17s}: {}'.format('seq_ids',self.seq_ids.tolist())) + logger.debug(' {:<17s}: {}'.format('ctx_lens',self.ctx_lens.tolist())) + logger.debug(' {:<17s}: {}'.format('gen_lens',self.gen_lens.tolist())) + logger.debug(' {:<17s}: 
{}'.format('total_lens',self.total_lens.tolist())) + logger.debug(' {:<17s}: {}'.format('expected_gen_lens',self.expected_gen_lens.tolist())) + logger.debug(' {:<17s}: {}'.format('finished',self.finished.tolist())) + logger.debug(' {:<17s}: {}'.format('step_lens_q',self.step_lens_q.tolist())) + + def add_new_seqs(self, seq_ids, context_lens, expected_gen_lens): + ctx_lens = context_lens[seq_ids] + gen_lens = torch.Tensor([0] * len(seq_ids)).to(dtype=torch.int32,device='cpu') + exp_gen_lens = expected_gen_lens[seq_ids] + finished = torch.Tensor([False] * len(seq_ids)).to(dtype=torch.bool,device='cpu') + + self.batch_size = self.batch_size + len(seq_ids) + self.finished = torch.cat([self.finished, finished], dim=0) + + if len(self.seq_ids) == 0: + self.seq_ids = seq_ids + self.ctx_lens = ctx_lens + self.gen_lens = gen_lens + self.expected_gen_lens = exp_gen_lens + else: + self.seq_ids = torch.cat([self.seq_ids, seq_ids],dim=0) + self.ctx_lens = torch.cat([self.ctx_lens, ctx_lens], dim=0) + self.gen_lens = torch.cat([self.gen_lens, gen_lens], dim=0) + self.expected_gen_lens = torch.cat([self.expected_gen_lens, exp_gen_lens], dim=0) + self.total_lens = self.ctx_lens + self.gen_lens + self.step_lens_q = torch.cat([self.step_lens_q, ctx_lens], dim=0) + + def remove_finished(self): + self.finished = torch.where( + self.gen_lens - self.expected_gen_lens < 0, False, True).to( + dtype=torch.bool,device='cpu') + self.batch_size = self.finished.logical_not().sum().item() + self.seq_ids = self.seq_ids[~self.finished] + self.ctx_lens = self.ctx_lens[~self.finished] + self.gen_lens = self.gen_lens[~self.finished] + self.total_lens = self.total_lens[~self.finished] + self.expected_gen_lens = self.expected_gen_lens[~self.finished] + self.gen_lens = self.gen_lens + 1 + self.total_lens = self.total_lens + 1 + self.step_lens_q = torch.ones([self.batch_size], dtype=torch.int32, device='cpu') + +param_types = [torch.float16] +if is_bf16_compatible(): + param_types.append(torch.bfloat16) 
+ +model_configs_infer = { + # test: b, h, hg, d, sq, skv, p, mask, bias + "infer_0": ModelConfig(4, 16, 16, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=8), + "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), + } + +qkv_formats = ['bshd', 'sbhd', 'thd'] + +def to_pretty_string(x: torch.Tensor): + return '['+','.join(['{:>3s}'.format(str(i)) for i in x.tolist()])+']' + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model", model_configs_infer.keys()) +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("is_paged", [False, True]) +@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("is_cuda_graph", [False, True]) +def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): + reset_rng_states() + logger = logging.getLogger('test_paged_attn') + + config = model_configs_infer[model] + layer_number = 1 + + inference_params_qkv_format = 'bshd' + if is_paged: + qkv_layout = "paged_kv_"+inference_params_qkv_format+'_2'+inference_params_qkv_format + else: + qkv_layout = '_'.join([inference_params_qkv_format]*3) + available_backends, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=False, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if backend == "FlashAttention" and not flash_attn_supported: + pytest.skip("FlashAttention backend is not supported") + if backend == "FusedAttention" and not fused_attn_supported: + pytest.skip("FusedAttention backend is not supported") + if backend == "UnfusedAttention" and not unfused_attn_supported: + pytest.skip("UnfusedAttention backend is not supported") + + os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention")) + os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention")) 
+ os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) + + total_requests = config.total_requests + # max_batch_size may be smaller than total_requests + max_batch_size = config.batch_size + # maximum KV length (context + generation) + max_seqlen_kv = config.max_seqlen_kv + + # mask type for inference + attn_mask_type = "padding" + + # page size in number of tokens (k cache and v cache are separate) + page_size = 256 if backend == "FlashAttention" else 16 + + max_seqlen_kv_roundup = max_seqlen_kv + if is_paged: + max_seqlen_kv_roundup = int((max_seqlen_kv + page_size - 1)//page_size * page_size) + else: + max_seqlen_kv_roundup = int((max_seqlen_kv + 63)//64 * 64) + cache_size = max_batch_size * max_seqlen_kv_roundup + total_num_pages = int(cache_size / page_size) + + context_ratio = 0.25 + gen_ratio = 1 - context_ratio + max_context_len = int(max_seqlen_kv * context_ratio) + max_gen_len = int(max_seqlen_kv * gen_ratio) + + # context lengths in Uniform distribution + context_lens = torch.randint(1, max_context_len, [total_requests], dtype=torch.int32, device='cpu') + # generation lengths in Exponential distribution + gen_dist = Exponential(1/max_gen_len) + gen_lens = gen_dist.sample((total_requests,)) + gen_lens = torch.where(gen_lens>max_gen_len, max_gen_len, gen_lens).to(dtype=torch.int32, device='cpu') + # arrival times in Poisson distribution + rate = torch.randint(1, max_batch_size, [1]).item() + interval_dist = Exponential(rate) + arrival_intervals = interval_dist.sample((total_requests,)) + arrival_times = torch.cumsum(arrival_intervals,dim=0).to(dtype=torch.int32, device='cpu') + last_arrival = arrival_times.max().item() + + logger.info("Simulation:") + logger.info(f" total num of requests: {total_requests}") + logger.info(f" k/v cache size: {cache_size} tokens") + logger.info(f" is_paged: {is_paged}") + logger.info(f" dtype: {dtype}") + if not is_paged: + logger.info(f" max_batch_size: {max_batch_size}") + logger.info(f" max_seqlen_kv: 
{max_seqlen_kv}") + else: + logger.info(f" total_num_pages: {total_num_pages}") + logger.info(f" page_size: {page_size}") + logger.info(f" context_lens: {to_pretty_string(context_lens)}") + logger.info(f" expected_gen_lens: {to_pretty_string(gen_lens)}") + logger.info(f" arrival_times: {to_pretty_string(arrival_times)}") + + model = ( + DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=0.0, + attn_mask_type="causal", + qkv_format='bshd', + ) + .cuda() + .eval() + ) + + q = 0.1 * torch.randn( + (total_requests, max_seqlen_kv, config.num_heads, config.head_dim_qk), + dtype=dtype, device="cuda") + k = 0.1 * torch.randn( + (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk), + dtype=dtype, device="cuda") + v = 0.1 * torch.randn( + (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_v), + dtype=dtype, device="cuda") + + logger.info("") + logger.info("=== Generating all tokens at once ===") + request_delays = torch.zeros([total_requests],dtype=torch.int32,device='cpu') + full_output = model( + query_layer=q, + key_layer=k, + value_layer=v, + qkv_format='bshd', + attn_mask_type="causal", + ) + + t = 1 + logger.info(f"total steps taken: {t}") + logger.info(f"arrival_times: {to_pretty_string(arrival_times)}") + logger.info(f"gen_lens: {to_pretty_string(gen_lens)}") + logger.info(f"serving_times: {to_pretty_string(arrival_times + request_delays)}") + + logger.info("") + logger.info("=== Generating one token at a time ===") + inference_params = InferenceParams( + max_batch_size=max_batch_size, + max_seqlen_kv=max_seqlen_kv_roundup, + num_heads_kv=config.num_gqa_groups, + head_dim_k=config.head_dim_qk, + head_dim_v=config.head_dim_v, + dtype=dtype, + is_paged=is_paged, + page_size=page_size, + total_num_pages=total_num_pages, + is_cuda_graph=is_cuda_graph, + num_heads_q=config.num_heads, + 
head_dim_q=config.head_dim_qk, + ) + inference_params.allocate_memory(layer_number) + inference_params.print() + + request_delays = torch.zeros([total_requests],dtype=torch.int32,device='cpu') + t = 0 + prev = Batch() + delayed_seq_ids = torch.Tensor().to(dtype=torch.int32,device='cpu') + while True: + logger.debug(f"time step {t}") + cur = prev.copy() + if t != 0: + cur.remove_finished() + if inference_params.is_paged: + inference_params.cache_manager.print_cache() + + arrived_seq_ids = torch.where(arrival_times == t, True, False).nonzero().view(-1) + if inference_params.is_paged: + allowed_num_new_seqs = max_batch_size - cur.batch_size + else: + allowed_num_new_seqs = 0 if cur.batch_size > 0 else max_batch_size + queuing_seq_ids = torch.cat([delayed_seq_ids, arrived_seq_ids],dim=0) + logger.debug(f"arrived seq_ids: {to_pretty_string(arrived_seq_ids)}") + logger.debug(f"previously delayed seq_ids: {to_pretty_string(delayed_seq_ids)}") + logger.debug(f"allowed num of new sequences: {allowed_num_new_seqs}") + if len(queuing_seq_ids) > allowed_num_new_seqs: + seq_ids = queuing_seq_ids[:allowed_num_new_seqs] + delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:] + request_delays[delayed_seq_ids.tolist()] += 1 + else: + seq_ids = queuing_seq_ids + delayed_seq_ids = torch.Tensor().to(dtype=torch.int32) + cur.add_new_seqs(seq_ids, context_lens, gen_lens) + cur.print(logger) + if inference_params.is_paged: + inference_params.cache_manager.print_cache() + + if cur.batch_size == 0: + # all sequences are finished + if t > last_arrival: + break + # not finished; run next iteration + else: + prev = cur.copy() + t += 1 + continue + + if not is_cuda_graph: + max_seqlen_q_infer = int((cur.step_lens_q.max().item() + 63)//64 * 64) + else: + max_seqlen_q_infer = max_seqlen_kv_roundup + + # create incremental input + if qkv_format == 'thd': + incremental_q = torch.Tensor().to(dtype=dtype, device='cuda') + incremental_k = torch.Tensor().to(dtype=dtype, device='cuda') + 
incremental_v = torch.Tensor().to(dtype=dtype, device='cuda') + for i,seq in enumerate(cur.seq_ids): + start = (cur.total_lens[i]-cur.step_lens_q[i]).item() + end = cur.total_lens[i].item() + incremental_q = torch.cat([incremental_q, + q[seq, start:end, :, :]],dim=0) + incremental_k = torch.cat([incremental_k, + k[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_qk)], dim=0) + incremental_v = torch.cat([incremental_v, + v[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_v)], dim=0) + else: + incremental_q = torch.zeros( + cur.batch_size, max_seqlen_q_infer, config.num_heads, config.head_dim_qk, + dtype=dtype, device='cuda') + incremental_k = torch.zeros( + cur.batch_size, max_seqlen_q_infer, config.num_gqa_groups, config.head_dim_qk, + dtype=dtype, device='cuda') + incremental_v = torch.zeros( + cur.batch_size, max_seqlen_q_infer, config.num_gqa_groups, config.head_dim_v, + dtype=dtype, device='cuda') + for i,seq in enumerate(cur.seq_ids): + start = (cur.total_lens[i]-cur.step_lens_q[i]).item() + end = cur.total_lens[i].item() + incremental_q[i, :cur.step_lens_q[i], :, :] = q[seq, start:end, :, :] + incremental_k[i, :cur.step_lens_q[i], :, :] = k[seq, start:end, :, :] + incremental_v[i, :cur.step_lens_q[i], :, :] = v[seq, start:end, :, :] + if qkv_format == 'sbhd': + incremental_q, incremental_k, incremental_v = [ + x.transpose(0,1) for x in [incremental_q, incremental_k, incremental_v]] + + cu_seqlens_q = torch.zeros(cur.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:cur.batch_size+1] = torch.cumsum(cur.step_lens_q, dim=0) + cu_seqlens_kv = torch.zeros(cur.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv[1:cur.batch_size+1] = torch.cumsum(cur.total_lens, dim=0) + + inference_params.step_dict = OrderedDict(zip(cur.seq_ids.tolist(), cur.step_lens_q.tolist())) + + line_output = model( + query_layer=incremental_q, + key_layer=incremental_k, + value_layer=incremental_v, + 
inference_params=inference_params, + attn_mask_type=attn_mask_type, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q_infer, + max_seqlen_kv=max_seqlen_kv_roundup, + qkv_format=qkv_format, + ) + + if backend != "FlashAttention": + tols = { + torch.float32: 1e-3, + torch.half: 1e-3, + torch.bfloat16: 1e-2, + } + else: + tols = { + torch.float32: 1e-3, + torch.half: 4e-3, + torch.bfloat16: 1e-2, + } + for i,seq in enumerate(cur.seq_ids): + if qkv_format == 'bshd': + torch.testing.assert_close( + full_output[seq,cur.total_lens[i]-1,:], + line_output[i,cur.step_lens_q[i]-1,:], + atol = tols[dtype], + rtol = tols[dtype]) + if qkv_format == 'sbhd': + torch.testing.assert_close( + full_output[seq,cur.total_lens[i]-1,:], + line_output[cur.step_lens_q[i]-1,i,:], + atol = tols[dtype], + rtol = tols[dtype]) + if qkv_format == 'thd': + torch.testing.assert_close( + full_output[seq,cur.total_lens[i]-1,:], + line_output[cu_seqlens_q[i+1]-1,:], + atol = tols[dtype], + rtol = tols[dtype]) + + prev = cur.copy() + t += 1 + + logger.info(f"total steps taken: {t}") + logger.info(f"arrival_times: {to_pretty_string(arrival_times)}") + logger.info(f"gen_lens: {to_pretty_string(gen_lens)}") + logger.info(f"serving_times: {to_pretty_string(arrival_times + request_delays)}") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c237dbaeb6..aef3fab070 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. 
+from collections import OrderedDict import math import os from typing import Dict, List, Optional @@ -34,6 +35,7 @@ Fp8Padding, Fp8Unpadding, ) +from transformer_engine.pytorch.attention import _cu_seqlens_cache from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace @@ -72,7 +74,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), } -backends_inference = ["FlashAttention", "UnfusedAttention"] +backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] input_formats_inference = ["sbhd", "bshd"] @@ -1935,14 +1937,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("input_format", input_formats_inference) @pytest.mark.parametrize("module", module_inference) @pytest.mark.parametrize("backend", backends_inference) -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend): +@pytest.mark.parametrize("is_paged", [False, True]) +@pytest.mark.parametrize("is_cuda_graph", [False, True]) +def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged, is_cuda_graph): + reset_rng_states() + + if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: + pytest.skip("FusedAttention and FlashAttention do not support FP32") + os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" elif backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + elif backend == 
"UnfusedAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" config = model_configs_inference[model_key] @@ -1955,7 +1967,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, # Limits the max size of KV-cache B_max = B - S_max = S + 2 + S_max = S if module == "TransformerLayer": model = TransformerLayer( @@ -1985,9 +1997,24 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, .eval() ) - inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max) + inference_params = InferenceParams( + max_batch_size=B_max, + max_seqlen_kv=S_max, + num_heads_kv=H, + head_dim_k=head_size, + dtype=dtype, + is_paged=is_paged, + total_num_pages=4, + page_size=256, + is_cuda_graph=is_cuda_graph, + num_heads_q=H, + head_dim_q=head_size, + ) + rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") + inference_params.step_dict = OrderedDict(zip(list(range(B)), [1]*B)) + input = torch.randn((S, B, D), dtype=dtype, device="cuda") if input_format == "bshd": input = input.transpose(0, 1).contiguous() @@ -2004,16 +2031,31 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, else: incremental_input = input[:, i, :].view(B, 1, D) + seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") + cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + seqlens_kv = (i+1) * torch.ones(B, dtype=torch.int32, device="cuda") + cu_seqlens_kv = torch.zeros(B + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) + + mask_type = "padding" + kwargs={} + if module == "TransformerLayer": + kwargs['self_attn_mask_type']=mask_type + else: + kwargs['attn_mask_type']=mask_type line_output = model( hidden_states=incremental_input, inference_params=inference_params, rotary_pos_emb=rotary_freqs if use_RoPE else None, + **kwargs, + max_seqlen_q=1, max_seqlen_kv=S, + 
cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) - inference_params.sequence_len_offset += 1 - if input_format == "sbhd": - incremental_output[i] = line_output.view(B, D) + incremental_output[i,:,:] = line_output.view(B, D) else: incremental_output[:, i, :] = line_output.view(B, D) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index ca23008edd..7ad30ee8b2 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -92,7 +92,8 @@ target_include_directories(transformer_engine PUBLIC # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas - CUDA::cudart) + CUDA::cudart + CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 4ea0ea5741..c44b8fbe76 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -37,6 +37,14 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD: + return NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD; + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD: + return NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -50,18 +58,24 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: case 
NVTE_QKV_Layout::NVTE_SBHD_SBH2D: case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD: return NVTE_QKV_Format::NVTE_SBHD; case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_BSH3D: case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD: return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Layout::NVTE_T3HD: case NVTE_QKV_Layout::NVTE_TH3D: case NVTE_QKV_Layout::NVTE_THD_T2HD: case NVTE_QKV_Layout::NVTE_THD_TH2D: case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD: return NVTE_QKV_Format::NVTE_THD; default: NVTE_ERROR("qkv_layout not supported!"); @@ -174,7 +188,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0) || + ((cudnn_runtime_version >= 90500) && + (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || + layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD))) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || @@ -613,7 +630,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const 
NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, @@ -625,6 +642,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = reinterpret_cast(page_table_k); + const Tensor *input_page_table_v = reinterpret_cast(page_table_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_K = reinterpret_cast(K); @@ -635,11 +654,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; size_t t_kv = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -647,6 +667,30 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_q = input_Q->data.shape[0]; t_kv = input_K->data.shape[0]; } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if 
(input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD) { + num_pages_k = input_K->data.shape[0]; + page_size_k = input_K->data.shape[1]; + num_pages_v = input_V->data.shape[0]; + page_size_v = input_V->data.shape[1]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD) { + num_pages_k = input_K->data.shape[1]; + page_size_k = input_K->data.shape[0]; + num_pages_v = input_V->data.shape[1]; + page_size_v = input_V->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -668,10 +712,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( @@ -722,11 +766,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); 
size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; size_t t_kv = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1a555a4999..5ab2452b49 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -49,12 +49,12 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { 
using namespace transformer_engine; @@ -65,16 +65,23 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); + if (is_paged_kv) { + NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); + } + // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if (is_ragged && cudnn_runtime_version >= 90600) { @@ -95,6 +102,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_kv, d_qk, d_v, + num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, scaling_factor, @@ -121,6 +129,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // bias std::shared_ptr, // seq_q std::shared_ptr, // seq_kv + std::shared_ptr, // page_table_k + std::shared_ptr, // page_table_v std::shared_ptr, // offset_q std::shared_ptr, // offset_k std::shared_ptr, // offset_v @@ -149,6 +159,7 @@ void 
fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr page_table_k, page_table_v; std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; @@ -158,17 +169,40 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector v_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + if (is_paged_kv) { + generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + } else { + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + } + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride)); if (is_ragged) { offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_q") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Q->set_ragged_offset(offset_q); + } + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_stride(v_stride)); + if (is_paged_kv) { + K->set_dim({num_pages_k, hg, page_size_k, d_qk}); + V->set_dim({num_pages_v, hg, page_size_v, d_v}); + } else if (is_ragged) { offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() 
.set_name("offset_k") .set_dim({b + 1, 1, 1, 1}) @@ -179,34 +213,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_ragged_offset(offset_q)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_ragged_offset(offset_k)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_ragged_offset(offset_v)); + K->set_dim({b, hg, s_kv, d_qk}).set_ragged_offset(offset_k); + V->set_dim({b, hg, s_kv, d_v}).set_ragged_offset(offset_v); } else { - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride)); + K->set_dim({b, hg, s_kv, d_qk}); + V->set_dim({b, hg, s_kv, d_v}); } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -252,6 +263,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); } + if (is_paged_kv) { + page_table_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_k") + .set_dim({b, 1, max_pages_per_seq_k, 1}) + .set_stride({{max_pages_per_seq_k, max_pages_per_seq_v, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); + page_table_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_v") + .set_dim({b, 1, max_pages_per_seq_v, 1}) + .set_stride({{max_pages_per_seq_v, max_pages_per_seq_v, 1, 1}}) + 
.set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_paged_attention_k_table(page_table_k); + sdpa_options.set_paged_attention_v_table(page_table_v); + sdpa_options.set_paged_attention_max_seq_len_kv(static_cast(s_kv)); + } + if (is_dropout) { dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seed") @@ -271,20 +298,19 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); if (is_ragged) { offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_o") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - O->set_output(true) - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_ragged_offset(offset_o); - } else { - O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); + O->set_ragged_offset(offset_o); } + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}); if (is_ragged && cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -292,16 +318,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, 1, h, 1}) - .set_ragged_offset(offset_stats); + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + Stats->set_stride({h * s_q, s_q, 1, 1}); } std::tuple, // Q @@ -314,6 +333,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? 
std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto page_table_tuple = + is_paged_kv ? std::make_tuple(page_table_k, page_table_v) : std::make_tuple(nullptr, nullptr); auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) : std::make_tuple(nullptr, nullptr, nullptr, nullptr); auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) @@ -330,13 +351,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, - padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); + padding_tuple, page_table_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k, + auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); @@ -389,6 +410,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } + if (is_paged_kv) { + variant_pack[page_table_k] = devPtrPageTableK; + variant_pack[page_table_v] = devPtrPageTableV; + } + if (is_ragged) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; @@ -445,17 +471,24 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); + if (is_paged_kv) { + NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); + } + // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if (is_ragged && cudnn_runtime_version >= 90600) { @@ -479,6 +512,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( s_kv, d_qk, d_v, + 0,0,0,0,0,0, bias_b, bias_h, scaling_factor, @@ -555,6 +589,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride)); if (is_ragged) { offset_q = 
mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_q") @@ -576,53 +630,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_ragged_offset(offset_q)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_ragged_offset(offset_k)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_ragged_offset(offset_v)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_ragged_offset(offset_o)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_ragged_offset(offset_o)); - } else { - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride)); + q->set_ragged_offset(offset_q); + k->set_ragged_offset(offset_k); + v->set_ragged_offset(offset_v); + o->set_ragged_offset(offset_o); } + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); if (is_ragged 
&& cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -630,18 +646,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, 1, h, 1}) - .set_data_type(fe::DataType_t::FLOAT) - .set_ragged_offset(offset_stats)); + stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + stats->set_stride({h * s_q, s_q, 1, 1}); } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -722,23 +729,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); if (is_ragged) { - dQ->set_output(true) - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_ragged_offset(offset_q); - dK->set_output(true) - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_ragged_offset(offset_k); - dV->set_output(true) - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_ragged_offset(offset_v); - } else { - dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); + dQ->set_ragged_offset(offset_q); + dK->set_ragged_offset(offset_k); + dV->set_ragged_offset(offset_v); } std::tuple, // q @@ -985,10 +982,10 @@ void 
fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, + max_batch_size, max_tokens, max_tokens, 0,0,0,0,0,0, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1201,10 +1198,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + max_batch_size, max_tokens_q, max_tokens_kv, 0,0,0,0,0,0, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, nullptr, nullptr, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1317,12 +1314,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, - size_t num_tokens_kv, bool is_training, float attn_scale, float 
p_dropout, + size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1346,6 +1343,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + void *devPtrPageTableK = page_table_k->data.dptr; + void *devPtrPageTableV = page_table_v->data.dptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; @@ -1415,10 +1414,10 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, 
devPtrSeqOffsetsKV, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 3a1216f891..cf6b2664bb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -61,12 +61,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, - size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fb7765e1a4..096b8e6ac5 
100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1680,6 +1680,7 @@ void fused_attn_fp8_fwd_impl_v1( s_kv, d, d, + 0,0,0,0,0,0, bias_b, bias_h, scaling_factor, @@ -1984,6 +1985,7 @@ void fused_attn_fp8_bwd_impl_v1( s_kv, d, d, + 0,0,0,0,0,0, bias_b, bias_h, scaling_factor, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index ca00218d9a..f7a9c8a8c6 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -116,6 +116,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 } break; case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) || @@ -222,6 +223,8 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 break; case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { strideA[batch_dim_idx] = s_q * h * d; @@ -242,6 +245,49 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_transpose_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * h * d; + 
strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index c060c4907d..34c25307c2 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -93,6 +93,12 @@ struct FADescriptor_v1 { std::int64_t s_kv; std::int64_t d_qk; std::int64_t d_v; + std::int64_t num_pages_k; + std::int64_t num_pages_v; + std::int64_t page_size_k; + std::int64_t page_size_v; + std::int64_t max_pages_per_seq_k; + std::int64_t max_pages_per_seq_v; std::int64_t bias_b; std::int64_t bias_h; float attnScale; @@ -108,10 +114,10 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 
&rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ae08f2a4aa..12e96f6d0a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -25,24 +25,30 @@ extern "C" { * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length * or padded to the same length, and `THD`-based layouts are used when sequences have - * different lengths in a batch. + * different lengths in a batch. `Paged_KV`-based layouts are used for paged attention. 
*/ enum NVTE_QKV_Layout { - NVTE_SB3HD = 0, /*!< SB3HD layout */ - NVTE_SBH3D = 1, /*!< SBH3D layout */ - NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ - NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ - NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ - NVTE_BS3HD = 5, /*!< BS3HD layout */ - NVTE_BSH3D = 6, /*!< BSH3D layout */ - NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ - NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ - NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ - NVTE_T3HD = 10, /*!< T3HD layout */ - NVTE_TH3D = 11, /*!< TH3D layout */ - NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ - NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ - NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_SB3HD = 0, /*!< SB3HD layout */ + NVTE_SBH3D = 1, /*!< SBH3D layout */ + NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ + NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ + NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ + NVTE_BS3HD = 5, /*!< BS3HD layout */ + NVTE_BSH3D = 6, /*!< BSH3D layout */ + NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ + NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ + NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ + NVTE_T3HD = 10, /*!< T3HD layout */ + NVTE_TH3D = 11, /*!< TH3D layout */ + NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ + NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ + NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_Paged_KV_BSHD_2BSHD = 15, /*!< Paged_KV_BSHD_2BSHD layout */ + NVTE_Paged_KV_BSHD_2SBHD = 16, /*!< Paged_KV_BSHD_2SBHD layout */ + NVTE_Paged_KV_SBHD_2BSHD = 17, /*!< Paged_KV_SBHD_2BSHD layout */ + NVTE_Paged_KV_SBHD_2SBHD = 18, /*!< Paged_KV_SBHD_2SBHD layout */ + NVTE_Paged_KV_THD_2BSHD = 19, /*!< Paged_KV_THD_2BSHD layout */ + NVTE_Paged_KV_THD_2SBHD = 20, /*!< Paged_KV_THD_2SBHD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -59,17 +65,21 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_H2D = 3, /*! HD_HD_HD QKV layouts, i.e. 
BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */ NVTE_HD_HD_HD = 4, + /*! Paged_KV_2BSHD QKV layouts, e.g. Paged_KV_THD_2BSHD */ + NVTE_Paged_KV_2BSHD = 5, + /*! Paged_KV_2SBHD QKV layouts, e.g. Paged_KV_BSHD_2SBHD */ + NVTE_Paged_KV_2SBHD = 6, }; /*! \enum NVTE_QKV_Format * \brief QKV formats */ enum NVTE_QKV_Format { - /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */ + /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_2BSHD, Paged_KV_SBHD_2SBHD */ NVTE_SBHD = 0, - /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */ + /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_2BSHD, Paged_KV_BSHD_2SBHD */ NVTE_BSHD = 1, - /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ + /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD, Paged_KV_THD_2BSHD, Paged_KV_THD_2SBHD */ NVTE_THD = 2, }; @@ -445,6 +455,8 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k]. + * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. 
@@ -465,7 +477,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..4a2938f9e3 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -51,7 +51,13 @@ .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ + .value("NVTE_Paged_KV_BSHD_2BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD) \ + .value("NVTE_Paged_KV_BSHD_2SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD) \ + .value("NVTE_Paged_KV_SBHD_2BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD) \ + .value("NVTE_Paged_KV_SBHD_2SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD) \ + .value("NVTE_Paged_KV_THD_2BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD) \ + .value("NVTE_Paged_KV_THD_2SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention.py 
b/transformer_engine/pytorch/attention.py index 3d72c6a9b3..13cd7ac637 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -13,6 +13,7 @@ import warnings import logging import functools +from einops import rearrange from dataclasses import dataclass, fields import numpy as np @@ -85,6 +86,8 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.kv_cache_manager_paged import PagedKVCacheManager +from transformer_engine.pytorch.kv_cache_manager_non_paged import NonPagedKVCacheManager # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 @@ -271,6 +274,8 @@ class AttentionParams: Whether `DotProductAttention` is in an `fp8_autocast` region. fp8_meta: Optional[Dict[str Any]], default = `None` The FP8 metadata tensor of `DotProductAttention`. + inference_params: Optional[object], default = `None` + Inference-related parameters. See InferenceParams for details. 
""" qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -296,6 +301,7 @@ class AttentionParams: is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None + inference_params: Optional[object] = None _alibi_cache = { @@ -365,6 +371,7 @@ def get_attention_backend( is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta + inference_params = attention_params.inference_params # Run config logger = logging.getLogger("DotProductAttention") @@ -469,6 +476,25 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") use_unfused_attention = False + # Filter: KV cache + # backend | non-paged/paged | precision + # --------------------------------------------------------------------------------- + # FlashAttention | non-paged/paged | FP16/BF16 + # FusedAttention | non-paged/paged | FP16/BF16 + # UnfusedDotProductAttention | non-paged/paged | FP32/FP16/BF16 + if inference_params is not None: + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling all backends as FP8 KV caching is not yet implemented") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if inference_params.is_paged: + if use_fused_attention and cudnn_version < (9, 5, 0): + logger.debug( + "Disabling FusedAttention as paged KV caching requires cuDNN 9.5+" + ) + use_fused_attention = False + # Filter: Head dimension if use_flash_attention and head_dim_qk != head_dim_v: if _flash_attn_is_installed: @@ -499,7 +525,7 @@ def get_attention_backend( use_fused_attention = False # Filter: QKV layout - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) if qkv_format == "thd": if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") @@ -980,7 +1006,8 @@ def 
get_attention_backend( class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order - to efficiently calculate and store the context during inference. + to efficiently calculate and store the context and previously generated tokens + during inference. Parameters ---------- @@ -988,39 +1015,351 @@ class InferenceParams: # pylint: disable=too-few-public-methods maximum batch size during inference. max_sequence_length : int maximum sequence length during inference. + num_heads: int + number of attention heads in key/value tensor. + head_dim_k: int + head size for the key tensor. + dtype: torch.dtype + data type for the KV cache. + head_dim_v: Optional[int], default = None + head size for the value tensor. If None, it will be set to head_dim_k. + is_paged: bool, default = False + whether the KV cache is paged or non-paged (contiguous). + total_num_pages: Optional[int], default = None + total number of pages in the K cache or V cache if is_paged = True. + page_size: Optional[int], default = None + page size in number of tokens if is_paged = True. 
""" - def __init__(self, max_batch_size, max_sequence_length): - self.max_sequence_length = max_sequence_length + def __init__(self, + max_batch_size: int, + max_seqlen_kv: int, + num_heads_kv: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + is_paged: bool = False, + total_num_pages: Optional[int] = None, + page_size: Optional[int] = None, + is_cuda_graph: bool = False, + num_heads_q: Optional[int] = None, + head_dim_q: Optional[int] = None, + ): self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 - self.key_value_memory_dict = {} + self.max_seqlen_kv = max_seqlen_kv + self.num_heads_kv = num_heads_kv + self.head_dim_k = head_dim_k + assert ( + dtype in [torch.float32, torch.float16, torch.bfloat16] + ), "Supported InferenceParams.dtype = {torch.float32, torch.float16, torch.bfloat16}. Found {dtype}." + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + self.is_paged = is_paged + self.is_cuda_graph = is_cuda_graph + self.page_table = None + + if not self.is_paged: + self.cache_manager = NonPagedKVCacheManager( + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + head_dim_v=self.head_dim_v, + is_cuda_graph=self.is_cuda_graph, + ) + else: + assert page_size is not None, "page_size is required when is_paged=True!" + assert total_num_pages is not None, "total_num_pages is required when is_paged=True!" 
+ self.page_size = page_size + self.max_seqlen_kv = self.max_seqlen_kv if self.max_seqlen_kv >= self.page_size else int((self.max_seqlen_kv + self.page_size -1)//self.page_size * self.page_size) + self.total_num_pages = total_num_pages + self.cache_manager = PagedKVCacheManager( + total_num_pages=self.total_num_pages, + page_size=self.page_size, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + head_dim_v=self.head_dim_v, + is_cuda_graph=self.is_cuda_graph, + ) - def swap_key_value_dict(self, batch_indices): + if self.is_cuda_graph: + assert num_heads_q is not None, "num_heads_q is required when is_cuda_graph=True!" + assert head_dim_q is not None, "head_dim_q is required when is_cuda_graph=True!" + self.num_heads_q = num_heads_q + self.head_dim_q = head_dim_q + + # memory format for the cache; at the moment, only 'bshd' is supported + self.qkv_format = 'bshd' + # layer numbers that we have kv cache for + self.layer_numbers = [] + # sequence ids that are stored in the cache + self.seq_ids = [] + # the full sequence lengths for sequences in seq_ids + self.seqlens = [0] * self.max_batch_size + # the {seq_id: step_len} information for a new inference step + # e.g. 
inference_params.step_dict = {2: 1, 3: 1, 4: 10}, if in this iteration, + # we have three sequences in the batch: sequences 2 and 3 are in generation phase + # with step_len = 1 and sequence 4 is in context phase with 10 new tokens + self.step_dict = collections.OrderedDict() + # the query buffer when is_cuda_graph = True + if self.is_cuda_graph: + self.q_buffer = {} + self.cu_seqlens_q_buffer = [] + self.cu_seqlens_kv_buffer = [] + + def print(self): + """Print InferenceParams parameters""" + logger = logging.getLogger("InferenceParams") + logger.debug(f"InferenceParams:") + logger.debug(f" dtype: {self.dtype}") + logger.debug(f" is_paged: {self.is_paged}") + if not self.is_paged: + logger.debug(f" max_batch_size: {self.max_batch_size}") + logger.debug(f" max_seqlen_kv: {self.max_seqlen_kv}") + else: + logger.debug(f" total_num_pages: {self.total_num_pages}") + logger.debug(f" page_size: {self.page_size}") + logger.debug(f" num_heads_kv: {self.num_heads_kv}") + logger.debug(f" head_dim: k: {self.head_dim_k}, v: {self.head_dim_v}") + logger.debug(f" layer_numbers: {self.layer_numbers}") + + + def allocate_memory(self, layer_number): + """ + Allocate memory for the KV cache for the layer #layer_number. + Both K cache and V cache are in 'bshd' format. + - non-paged: + - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] + - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] + - paged: + - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] + - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] + If is_cuda_graph = True, several buffers are also allocated. + - Q buffer: [max_batch_size, max_seqlen_kv, num_heads_q, head_dim_q] + - cu_seqlens_q buffer: [max_batch_size + 1] + - cu_seqlens_kv buffer: [max_batch_size + 1] """ - Reorders the KV cache using the specified batch indices. 
+ self.layer_numbers.append(layer_number) + self.cache_manager.allocate_memory(layer_number) + + if self.is_cuda_graph: + self.max_seqlen_q = self.max_seqlen_kv + self.q_buffer[layer_number] = torch.zeros( + self.max_batch_size, + self.max_seqlen_q, + self.num_heads_q, + self.head_dim_q, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_q_buffer = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_kv_buffer = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) - Parameters - ---------- - batch_indices : List[int] - Sequence of indices to reorder along the batch dimensions of - the KV cache. Must have a length equal to the batch size. + def reshape_and_copy_q(self, + q: torch.Tensor, + source_qkv_format: str, + target_qkv_format: str, + layer_number: Optional[int] = None, + ): + """ + Convert the new query tokens from 'source_qkv_format' to 'target_qkv_format', + so that it is consistent with the KV cache format. At the moment, only 'bshd' format + is supported for target_qkv_format. If is_cuda_graph = True, also copy the new query + tensor to the appropriate q_buffer. 
""" - if len(self.key_value_memory_dict) == 0: - raise ValueError("should not swap when dict in empty") + actual_batch_size = len(self.step_dict) + seqlens_q = list(self.step_dict.values()) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size+1)] + batch_wide_max_seqlen_q = int((max(seqlens_q) + 63)//64 * 64) + if not self.is_cuda_graph: + if source_qkv_format == 'bshd': + q = q.contiguous() + if source_qkv_format == 'sbhd': + q = q.transpose(0,1).contiguous() + if source_qkv_format == 'thd': + padded_q = torch.zeros( + actual_batch_size, batch_wide_max_seqlen_q, q.shape[-2], q.shape[-1], + dtype=q.dtype, device='cuda') + for i in range(actual_batch_size): + padded_q[i, :seqlens_q[i], :, :] = q[cu_seqlens_q[i]:cu_seqlens_q[i+1], :, :] + q = padded_q + + if source_qkv_format in ['bshd', 'sbhd']: + self.max_seqlen_q = q.shape[1] + else: + self.max_seqlen_q = batch_wide_max_seqlen_q - for layer_number, inference_memory in self.key_value_memory_dict.items(): - inference_key_memory, inference_value_memory = inference_memory + # bshd: [actual_batch_size, batch_wide_max_seqlen_q, num_heads_q, head_dim_q] + return q + else: assert ( - len(batch_indices) == inference_key_memory.shape[1] - ) # make sure batch size is the same - new_inference_key_memory = inference_key_memory[:, batch_indices] - new_inference_value_memory = inference_value_memory[:, batch_indices] - self.key_value_memory_dict[layer_number] = ( - new_inference_key_memory, - new_inference_value_memory, + layer_number is not None and layer_number in self.layer_numbers + ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" 
+ q_buffer = self.q_buffer[layer_number] + for i in range(actual_batch_size): + if source_qkv_format == 'bshd': + q_buffer[i, :seqlens_q[i], :, :] = q[i, :seqlens_q[i], :, :] + if source_qkv_format == 'sbhd': + q_buffer[i, :seqlens_q[i], :, :] = q[:seqlens_q[i], i, :, :] + if source_qkv_format == 'thd': + q_buffer[i, :seqlens_q[i], :, :] = q[cu_seqlens_q[i]:cu_seqlens_q[i+1], :, :] + q_buffer[i, seqlens_q[i]:, :, :].fill_(0) + + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]]*(self.max_batch_size - actual_batch_size) + self.cu_seqlens_q_buffer.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device='cpu')) + + # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] + return q_buffer + + def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): + """ + Convert the k cache and v cache from paged to non-paged format. This function + can be used for debugging purposes or for backends that do not have paged attention + support yet, for example, UnfusedDotProductAttention. + + It can be called after update_cache(). Based on the page table, it re-indexes the cache + tensors and returns the contiguous, non-paged, key and value tensors. The kv cache tensors + are assumed to be in 'bshd' format (see self.allocate_memory), and the returned key and + value tensors will be in :attr:`qkv_format` to be consistent with the original inputs. + + Parameters + ---------- + layer_number: int + The layer number of the kv cache + qkv_format: str + The format of the returned key and value tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Non-paged key cache tensor + v_cache: torch.Tensor + Non-paged value cache tensor + """ + k_cache, v_cache = self.cache_manager.cache[layer_number] + page_table = self.cache_manager.page_table + batch_size = page_table.shape[0] + actual_batch_size = len(self.step_dict) + new_k_cache = rearrange( + k_cache[page_table.flatten()], + "(b npages) page_size ... 
-> b (npages page_size) ...", + b=batch_size, + ) + new_v_cache = rearrange( + v_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, ) + for i in range(actual_batch_size): + new_k_cache[i, self.seqlens[i]:,:,:].fill_(0) + new_v_cache[i, self.seqlens[i]:,:,:].fill_(0) + if qkv_format == 'bshd': + new_k_cache = new_k_cache.contiguous() + new_v_cache = new_v_cache.contiguous() + if qkv_format == 'sbhd': + new_k_cache = new_k_cache.transpose(0,1).contiguous() + new_v_cache = new_v_cache.transpose(0,1).contiguous() + if qkv_format == 'thd': + packed_k_cache = torch.Tensor().to(dtype=k_cache.dtype,device=k_cache.device) + packed_v_cache = torch.Tensor().to(dtype=v_cache.dtype,device=v_cache.device) + for i in range(batch_size): + packed_k_cache = torch.cat([packed_k_cache, new_k_cache[i,:self.seqlens[i],:,:]], dim=0) + packed_v_cache = torch.cat([packed_v_cache, new_v_cache[i,:self.seqlens[i],:,:]], dim=0) + new_k_cache = packed_k_cache.contiguous() + new_v_cache = packed_v_cache.contiguous() + return new_k_cache, new_v_cache + + def update_cache(self, + layer_number: int, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str, + ): + """ + Update KV cache with the new key/value tokens for a given inference iteration. + + NonPagedKVCacheManager and PagedKVCacheManager are two examples of the cache manager. + Users can write their own cache manager with their own step() function. + + If the inference iteration has only generation sequences, :attr:`k` and :attr:`v` tensors + should have shape: + - [batch_size, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', + - [1, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and + - [batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. 
+ + If the inference iteration has both generation sequences and context sequences, :attr:`k` + and :attr:`v` should be arranged in a way so that the sequences in generation phase come + before the sequences in context phase, in the tensor. They should have the following shape. + - [batch_size, max_seqlen, num_heads, head_dim] for :attr:`qkv_format` = 'bshd' + - [max_seqlen, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and + - [total_num_new_tokens, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. + Here, max_seqlen is the maximum sequence length for the new tokens in the batch, and it may + be smaller than InferenceParams.max_seqlen_kv. + + Take a batch of 4, with seq_ids = [0, 1, 2, 3], as an example. At iteration t, all 4 sequences + are processed, after which, sequence 2 is determined to be 'finished'. For iteration t+1, there + may or may not be a new sequence added to the batch. + + If no new sequence is added, input tensors :attr:`k` and :attr:`v` should have shape + [3, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', [1, 3, num_heads, head_dim] for + :attr:`qkv_format` = 'sbhd', and [3, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. + + If one new sequence is added, for example, sequence 8 with 10 context tokens, then input tensors + :attr:`k` and :attr:`v` should be in [4, 10, num_heads, head_dim] shape if + :attr:`qkv_format` = 'bshd', [10, 4, num_heads, head_dim] if :attr:`qkv_format` = 'sbhd', + or [13, num_heads, head_dim] if :attr:`qkv_format` = 'thd'. 
+ + Parameters + ---------- + layer_number: int + The layer number of the kv cache + k: torch.Tensor + The new key tokens for the current iteration + v: torch.Tensor + The new value tokens for the current iteration + qkv_format: str + The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + The key cache tensor, containing tokens from both previous and current iterations + v_cache: torch.Tensor + The value cache tensor, containing tokens from both previous and current iterations + page_table: torch.Tensor + The page table if is_paged = True; else `None` + """ + outputs = self.cache_manager.step(layer_number, k, v, self.step_dict, qkv_format) + self.seq_ids = list(self.cache_manager.sequences.keys()) + self.seqlens = list(self.cache_manager.sequences.values()) + + if not self.is_paged: + k_cache, v_cache = outputs + page_table = None + else: + k_cache, v_cache, page_table = outputs + self.page_table = page_table + + if self.is_cuda_graph: + actual_batch_size = len(self.seqlens) + cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size+1)] + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]]*(self.max_batch_size - actual_batch_size) + self.cu_seqlens_kv_buffer.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device='cpu')) + + # k_cache and v_cache are in InferenceParams.qkv_format format + return k_cache, v_cache, page_table @torch.no_grad() @@ -4736,12 +5075,17 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, ) -> torch.Tensor: """Unfused attention fprop""" assert ( qkv_layout in QKVLayouts ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" 
- qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + + qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) + if inference_params is not None and inference_params.is_paged: + key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number, qkv_format) + if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now query_layer, key_layer, value_layer = [ @@ -4754,6 +5098,24 @@ def forward( ) if "padding" in attn_mask_type: if self.attention_type == "self": + if attention_mask is None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) + for i in range(batch_size): + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor( + [False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask = attention_mask_q.to(device="cuda") assert attention_mask.shape == ( batch_size, 1, @@ -4764,6 +5126,43 @@ def forward( attention_mask.squeeze(1).unsqueeze(3), attention_mask ) else: + if attention_mask is None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) + attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) + for i in range(batch_size): + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor( + [False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask_kv = torch.cat( + [ + attention_mask_kv, + torch.Tensor( + [False] * seqlens_kv[i] + + [True] * (max_seqlen_kv - seqlens_kv[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask = ( + attention_mask_q.to(device="cuda"), + 
attention_mask_kv.to(device="cuda"), + ) assert ( len(attention_mask) == 2 and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) @@ -5228,9 +5627,9 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + inference_params: Optional[InferenceParams] = None, ) -> torch.Tensor: """flash-attn fprop""" - assert all( x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] @@ -5250,7 +5649,7 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): if qkv_format == "sbhd": @@ -5295,39 +5694,52 @@ def forward( if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" - # [b * s, h, d] - query_layer, key_layer, value_layer = [ - x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) - for x in [query_layer, key_layer, value_layer] - ] - if self.attention_type == "self": - assert ( - max_seqlen_q == max_seqlen_kv - ), "Maximum sequence length for Q and KV should be the same." - if cu_seqlens_q is None: + if inference_params is None or (inference_params is not None and not inference_params.is_paged): + # [b * s, h, d] + query_layer, key_layer, value_layer = [ + x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) + for x in [query_layer, key_layer, value_layer] + ] + + if self.attention_type == "self": assert ( - attention_mask is not None - ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask) + max_seqlen_q == max_seqlen_kv + ), "Maximum sequence length for Q and KV should be the same." 
+ if cu_seqlens_q is None: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask) + else: + indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + cu_seqlens_kv = cu_seqlens_q + query_layer, key_layer, value_layer = PackTensors.apply( + indices_q, query_layer, key_layer, value_layer + ) else: - indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - cu_seqlens_kv = cu_seqlens_q - query_layer, key_layer, value_layer = PackTensors.apply( - indices_q, query_layer, key_layer, value_layer - ) + if cu_seqlens_q is None or cu_seqlens_kv is None: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) + cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) + else: + indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) + query_layer = PackTensors.apply(indices_q, query_layer) + key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) else: - if cu_seqlens_q is None or cu_seqlens_kv is None: + # [b * s, h, d] + query_layer = query_layer.reshape(query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:]) + if cu_seqlens_q is None: assert ( attention_mask is not None ), "Please provide attention_mask for padding!" 
- cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) - cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask if self.attention_type == "self" else attention_mask[0]) else: indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) query_layer = PackTensors.apply(indices_q, query_layer) - key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) else: # Cumulative sequence lengths for unpadded data if cu_seqlens_q is None: @@ -5406,6 +5818,8 @@ def forward( else: if _flash_attn_2_5_7_plus: fa_optional_forward_kwargs["block_table"] = None + if inference_params is not None: + fa_optional_forward_kwargs["block_table"] = inference_params.page_table func = ( flash_attn_varlen_func if not _use_flash_attn_3 @@ -5589,7 +6003,7 @@ def forward( fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) assert ( qkv_group == 1 ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." @@ -5983,7 +6397,7 @@ def forward( q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) assert qkv_group == 2, ( "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " f"but found {qkv_layout}." 
@@ -6404,6 +6818,8 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, + page_table_k, + page_table_v, q, k, v, @@ -6438,7 +6854,7 @@ def forward( q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) if qkv_group == 1: dim = qkv_layout.find("3") qkv = _combine_tensors([q, k, v], dim) @@ -6452,7 +6868,7 @@ def forward( q_fp8 = cast_to_fp8( q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") + dim = qkv_layout.replace("paged_kv_","").split("_")[1].find("2") kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) kv_fp8 = cast_to_fp8( @@ -6527,7 +6943,7 @@ def forward( if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate if is_input_fp8: - qkv_group = len(qkv_layout.split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) if qkv_group == 1: dim = qkv_layout.find("3") qkv = _combine_tensors([q, k, v], dim) @@ -6549,7 +6965,7 @@ def forward( fp8_dtype_forward, TE_DType[q.dtype], ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") + dim = qkv_layout.replace("paged_kv_","").split("_")[1].find("2") kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) kv_no_fp8 = cast_from_fp8( @@ -6615,6 +7031,8 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, + page_table_k, + page_table_v, None, # d_scale_qkv 0, # d_scale_qkv_offset None, # d_scale_s @@ -6827,7 +7245,7 @@ def backward(ctx, d_out): dtype=d_out_f8tensor.dtype, ) else: - qkv_group = len(ctx.qkv_layout.split("_")) + qkv_group = len(ctx.qkv_layout.replace("paged_kv_","").split("_")) if qkv_group == 1: dim = ctx.qkv_layout.find("3") dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim) @@ -6851,7 +7269,7 @@ def 
backward(ctx, d_out): fp8_dtype_backward, ctx.qkv_dtype, ).view(dq_fp8.shape) - dim = ctx.qkv_layout.split("_")[1].find("2") + dim = ctx.qkv_layout.replace("paged_kv_","").split("_")[1].find("2") dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim) dkv_c_fp8 = dkv_fp8.view( -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] @@ -6936,6 +7354,8 @@ def backward(ctx, d_out): None, None, None, + None, + None, dq, dk, dv, @@ -6966,6 +7386,8 @@ def backward(ctx, d_out): None, None, None, + None, + None, dq, dk, dv, @@ -7070,6 +7492,8 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_kv_padded: Optional[torch.Tensor] = None, + page_table_k: Optional[torch.Tensor] = None, + page_table_v: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", @@ -7109,21 +7533,17 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) if qkv_format in ["sbhd", "bshd"]: if qkv_format == "sbhd": - batch_size, max_seqlen_q, max_seqlen_kv = ( - query_layer.shape[1], - query_layer.shape[0], - key_layer.shape[0], - ) + batch_size = query_layer.shape[1] + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv if qkv_format == "bshd": - batch_size, max_seqlen_q, max_seqlen_kv = ( - query_layer.shape[0], - query_layer.shape[1], - key_layer.shape[1], - ) + batch_size = query_layer.shape[0] + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv max_seqlen_q *= cp_size max_seqlen_kv *= cp_size if "padding" in attn_mask_type: 
@@ -7233,6 +7653,8 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, + page_table_k, + page_table_v, query_layer, key_layer, value_layer, @@ -7875,10 +8297,6 @@ def forward( assert ( attn_mask_type in AttnMaskTypes ), f"Attention mask type {attn_mask_type} is not supported!" - if qkv_format == "thd": - assert ( - "padding" in attn_mask_type - ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" if window_size is None: window_size = self.window_size @@ -7895,52 +8313,6 @@ def forward( if qkv_format is None: qkv_format = self.qkv_format - if inference_params is not None: - assert self.layer_number is not None, "Layer number must be set!" - - # convert causal to causal_bottom_right in inference when KV-caching is in use - # so users can run with the same attn_mask_type for training and inference - if attn_mask_type in ["causal", "padding_causal"]: - attn_mask_type = attn_mask_type + "_bottom_right" - - if qkv_format == "bshd": - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - - # Copy keys and values into KV-cache - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - key_layer - ) - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - value_layer - ) - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] 
- - if qkv_format == "bshd": - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" assert qkv_format in [ "sbhd", "bshd", @@ -7976,6 +8348,65 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + page_table = None + if inference_params is not None: + assert self.layer_number is not None, "Layer number must be set!" + + # remember original format for output purposes + orig_qkv_format = qkv_format + + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + + # convert to cross attention type when KV cache is in use + self.attention_type = "cross" + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type + + # force tensors to be contiguous if not already + query_layer, key_layer, value_layer = [ + x.contiguous() if not x.is_contiguous() else x for x in [ + query_layer, key_layer, value_layer]] + + # reshape the query tensor + # cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but + # flash-attention and unfused attention will need q/k/v in the + # same qkv_format + target_qkv_format = inference_params.qkv_format + query_layer = inference_params.reshape_and_copy_q( + query_layer, qkv_format, target_qkv_format, self.layer_number) + + # update KV cache and return the full key/value tensors + # full key/value tensors are in 
inference_params.qkv_format format + key_layer, value_layer, page_table = inference_params.update_cache( + self.layer_number, + key_layer, + value_layer, + qkv_format, + ) + + # update cu_seqlens tensors + if inference_params.is_cuda_graph: + cu_seqlens_q = inference_params.cu_seqlens_q_buffer + cu_seqlens_kv = inference_params.cu_seqlens_kv_buffer + max_seqlen_q = inference_params.max_seqlen_q + max_seqlen_kv = inference_params.max_seqlen_kv + + # query tensor is now in inference_params.qkv_format + qkv_format = target_qkv_format + + if qkv_format == "thd": + assert ( + "padding" in attn_mask_type + ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" + assert ( + key_layer.shape[-2] == self.num_gqa_groups_per_partition + and value_layer.shape[-2] == self.num_gqa_groups_per_partition + ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups_per_partition} heads! Found {key_layer.shape[-2]} in key_layer and {value_layer.shape[-2]} in value_layer." + cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -7998,18 +8429,6 @@ def forward( batch_size = query_layer.shape[0] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size - if cu_seqlens_q is not None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - assert all( - seqlens_q <= max_seqlen_q - ), """Sequence lengths indicated by cu_seqlens_q must be no greater than - the sequence dimension in 'query_layer'!""" - if cu_seqlens_kv is not None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - assert all( - seqlens_kv <= max_seqlen_kv - ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than - the sequence dimension in 'key_layer' and 'value_layer'!""" if cu_seqlens_q is None or cu_seqlens_kv is None: if "padding" in attn_mask_type: assert ( @@ -8045,6 +8464,9 @@ def forward( qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout( query_layer, key_layer, value_layer, qkv_format=qkv_format ) + 
# convert qkv layout to its corresponding paged attention layout + if inference_params is not None and inference_params.is_paged: + qkv_layout = "paged_kv_"+qkv_format+"_2"+inference_params.qkv_format global _alibi_cache if alibi_slopes is not None: @@ -8126,6 +8548,7 @@ def forward( is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, + inference_params=inference_params, ) global _attention_backends, _use_flash_attn_3 if ( @@ -8161,6 +8584,10 @@ def forward( fused_attention_backend = _attention_backends["fused_attention_backend"] use_unfused_attention = _attention_backends["use_unfused_attention"] + if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0: + raise ValueError("No dot product attention support for the provided inputs!") + + output = None if use_flash_attention: if core_attention_bias_type == "alibi": alibi_slopes, _ = get_alibi( @@ -8169,7 +8596,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, ) - return self.flash_attention( + output = self.flash_attention( query_layer, key_layer, value_layer, @@ -8188,6 +8615,7 @@ def forward( max_seqlen_kv=max_seqlen_kv, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + inference_params=inference_params, ) if use_fused_attention: @@ -8206,7 +8634,7 @@ def forward( bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) if checkpoint_core_attention: - return self._checkpointed_attention_forward( + output = self._checkpointed_attention_forward( self.fused_attention, query_layer, key_layer, @@ -8216,6 +8644,8 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, + page_table_k=page_table, + page_table_v=page_table, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, @@ -8232,7 +8662,7 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, ) - return self.fused_attention( + output = 
self.fused_attention( query_layer, key_layer, value_layer, @@ -8241,6 +8671,8 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, + page_table_k=page_table, + page_table_v=page_table, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, @@ -8274,7 +8706,7 @@ def forward( window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask ) if checkpoint_core_attention: - return self._checkpointed_attention_forward( + output = self._checkpointed_attention_forward( self.unfused_attention, query_layer, key_layer, @@ -8287,8 +8719,9 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + inference_params=inference_params, ) - return self.unfused_attention( + output = self.unfused_attention( query_layer, key_layer, value_layer, @@ -8300,9 +8733,25 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + inference_params=inference_params, ) - raise ValueError("No dot product attention support for the provided inputs!") + if inference_params is not None: + batch_size = len(inference_params.step_dict) + seqlen = inference_params.seqlens[0] + step_lens = list(inference_params.step_dict.values()) + max_seqlen_q = max(list(inference_params.step_dict.values())) + if orig_qkv_format == "bshd": + output = output[:batch_size, :max_seqlen_q].contiguous() + if orig_qkv_format == "sbhd": + output = output[:batch_size, :max_seqlen_q].transpose(0,1).contiguous() + if orig_qkv_format == "thd": + packed_output = torch.Tensor().to(dtype=output.dtype,device=output.device) + for i in range(batch_size): + packed_output = torch.cat([packed_output, output[i,:step_lens[i]]], dim=0) + output = packed_output.contiguous() + + return output class MultiheadAttention(torch.nn.Module): @@ -8486,7 +8935,7 @@ def __init__( 
self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = check_set_window_size(attn_mask_type, window_size) - self.layer_number = layer_number + self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type self.get_rng_state_tracker = get_rng_state_tracker @@ -8855,28 +9304,8 @@ def forward( # Pre-allocate memory for key-values for inference # ================================================= - if inference_params and self.layer_number is not None: - assert ( - self.qkv_format != "thd" - ), "qkv_format == thd is not supported for an inference with KV-cache!" - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] + if inference_params is not None and self.layer_number not in inference_params.layer_numbers: + inference_params.allocate_memory(self.layer_number) # ====================== # Query, Key, and Value @@ -9056,9 +9485,10 @@ def forward( elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) else: - raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.") + raise ValueError(f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE.") - sequence_start = inference_params.sequence_len_offset + # TODO: consider cases where sequences have different seqlens + sequence_start = inference_params.seqlens[0] sequence_end = 
sequence_start + sequence_length q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index bf5ca4d98e..4e83abefc9 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -52,6 +52,12 @@ "thd_t2hd", "thd_th2d", "thd_thd_thd", + "paged_kv_bshd_2bshd", + "paged_kv_bshd_2sbhd", + "paged_kv_sbhd_2bshd", + "paged_kv_sbhd_2sbhd", + "paged_kv_thd_2bshd", + "paged_kv_thd_2sbhd", ) LayerTypes = ("encoder", "decoder") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 1932e9feb2..029e4e419d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -50,6 +50,12 @@ "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD, "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D, "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD, + "paged_kv_bshd_2bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_2BSHD, + "paged_kv_bshd_2sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_2SBHD, + "paged_kv_sbhd_2bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_2BSHD, + "paged_kv_sbhd_2sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_2SBHD, + "paged_kv_thd_2bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_2BSHD, + "paged_kv_thd_2sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_2SBHD, } AttnBiasType = { @@ -900,6 +906,8 @@ def fused_attn_fwd( attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, + page_table_k: torch.Tensor = None, + page_table_v: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, @@ -957,6 +965,10 @@ def fused_attn_fwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] + page_table_k: torch.Tensor, default = 
None + page table for K cache; shape [batch_size, max_pages_per_seq_k] + page_table_v: torch.Tensor, default = None + page table for V cache; shape [batch_size, max_pages_per_seq_v] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_qkv_offset: int, default = META_QKV @@ -1098,6 +1110,8 @@ def fused_attn_fwd( qkv_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, + page_table_k, + page_table_v, d_scale_qkv, d_scale_qkv_offset, d_scale_s, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b039bf2d1b..253ea2bafe 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -110,6 +110,8 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, + const c10::optional page_table_k, + const c10::optional page_table_v, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, const c10::optional scale_S, const int scale_S_offset, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 8088a2b8f1..ea66a7408b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -781,6 +781,8 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, + const c10::optional page_table_k, + const c10::optional page_table_v, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, const c10::optional scale_S, const int scale_S_offset, @@ -808,6 +810,7 @@ std::vector 
fused_attn_fwd( TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; + TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -879,6 +882,19 @@ std::vector fused_attn_fwd( nullptr, nullptr, nullptr); } + if ((page_table_k.has_value()) && (page_table_v.has_value())) { + auto page_table_k_sizes = page_table_k.value().sizes().vec(); + std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; + auto page_table_v_sizes = page_table_v.value().sizes().vec(); + std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; + te_page_table_k = makeTransformerEngineTensor(page_table_k.value().data_ptr(), + page_table_k_shape, DType::kInt32, + nullptr, nullptr, nullptr); + te_page_table_v = makeTransformerEngineTensor(page_table_v.value().data_ptr(), + page_table_v_shape, DType::kInt32, + nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); @@ -899,7 +915,7 @@ std::vector fused_attn_fwd( nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, + te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -941,7 +957,7 @@ std::vector fused_attn_fwd( nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), 
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, + te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py new file mode 100644 index 0000000000..2b62ac2067 --- /dev/null +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Non-Paged KV Cache Manager.""" +from collections import OrderedDict +from typing import List, Optional +import torch + +class NonPagedKVCacheManager: + """ + The non-paged KV cache manager. 
+ """ + def __init__(self, + max_batch_size: int, + max_seqlen: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + is_cuda_graph: bool = False, + ): + """Initialize the KV cache""" + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + self.is_cuda_graph = is_cuda_graph + + # sequences contained in the kv cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # KV cache tuple (k_cache, v_cache) + self.cache = {} + + def allocate_memory(self, layer_number): + """Allocate memory for the KV cache""" + k_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + def step(self, layer_number, k: torch.Tensor, v: torch.Tensor, step_dict: OrderedDict, qkv_format: str): + """ + Update the non-paged KV cache for a given inference iteration. + For more details, please refer to InferenceParams.update_cache(). 
+ + Parameters + ---------- + layer_number: int + The layer number of kv cache to operate on + k: torch.Tensor + The new key tokens for the current iteration + v: torch.Tensor + The new value tokens for the current iteration + step_dict: OrderedDict + The {seq_id: step_len} information for the new inference step + qkv_format: str + The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + The key cache tensor containing previous and the current tokens + v_cache: torch.Tensor + The value cache tensor containing previous and the current tokens + """ + k_cache, v_cache = self.cache[layer_number] + prev_batch_size = len(self.sequences) + batch_size = len(step_dict) + + # Reorder cache + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + unfinished_indices = [i for i,j in enumerate(self.sequences) if j in unfinished_seqs] + finished_indices = [i for i,j in enumerate(self.sequences) if j in finished_seqs] + batch_indices = unfinished_indices + finished_indices \ + + list(range(prev_batch_size, self.max_batch_size)) + new_k_cache = k_cache[batch_indices, :] + new_v_cache = v_cache[batch_indices, :] + new_k_cache = new_k_cache.contiguous() + new_v_cache = new_v_cache.contiguous() + + # Advance unfinished sequences + for i in unfinished_seqs: + self.sequences[i] += 1 + + # Remove finished sequences + for i in finished_seqs: + self.sequences.pop(i) + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for i in new_seqs: + self.sequences[i] = step_dict[i] + + # Copy new key/value tokens to cache + step_lens = list(step_dict.values()) + cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1,batch_size+1)] + for i,seq in enumerate(self.sequences): + seq_s = self.sequences[seq] - step_dict[seq] + seq_e = self.sequences[seq] + if qkv_format == 'bshd': + new_k_cache[i, seq_s:seq_e, :, :] = k[i, :step_dict[seq], :, :] + new_v_cache[i, 
seq_s:seq_e, :, :] = v[i, :step_dict[seq], :, :] + if qkv_format == 'sbhd': + new_k_cache[i, seq_s:seq_e, :, :] = k[:step_dict[seq], i, :, :] + new_v_cache[i, seq_s:seq_e, :, :] = v[:step_dict[seq], i, :, :] + if qkv_format == 'thd': + new_k_cache[i, seq_s:seq_e, :, :] = k[cu_seqlens[i]:cu_seqlens[i+1], :, :] + new_v_cache[i, seq_s:seq_e, :, :] = v[cu_seqlens[i]:cu_seqlens[i+1], :, :] + self.cache[layer_number] = (new_k_cache, new_v_cache) + + # Return full key/value tensors for attention calculation + if self.is_cuda_graph: + # [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] + return new_k_cache, new_v_cache + else: + # [actual_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] + new_k_cache = new_k_cache[:batch_size].contiguous() + new_v_cache = new_v_cache[:batch_size].contiguous() + return new_k_cache, new_v_cache diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py new file mode 100644 index 0000000000..8591a843d1 --- /dev/null +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Paged KV Cache Manager.""" +from collections import defaultdict, OrderedDict +from typing import List, Optional +import logging + +import torch + + +class Page(object): + """A single page""" + def __init__(self, page_id: int): + self.page_id = page_id + self.allocated = 0 + + def allocate_page(self): + self.allocated = True + + def deallocate_page(self): + self.allocated = False + +class PagedKVCacheManager(object): + """ + Paged KV cache manager. It supports a set of utilities including adding and removing + sequences, and copying new key/value tokens to the cache. Users can overwrite this class + for more custom implementations. 
+ """ + def __init__(self, + total_num_pages: int, + page_size: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + max_batch_size: int, + max_seqlen: int, + head_dim_v: Optional[int] = None, + is_cuda_graph: bool = False, + ): + """Initialize the KV cache""" + self.total_num_pages = total_num_pages + self.page_size = page_size + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.max_pages_per_seq = max_seqlen // self.page_size + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + self.is_cuda_graph = is_cuda_graph + + # sequences contained in the kv cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # kv cache, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # free pages allowed to allocate, [Page(),...] + self.free_pages = [] + # allocated pages, {seq_id: [page_id,...]} + self.allocated_pages = defaultdict(list) + # page table, [batch_size, max_pages_per_seq] + self.page_table = None + + def allocate_memory(self, layer_number): + """Allocate memory for the KV cache""" + k_cache = torch.empty( + self.total_num_pages, self.page_size, self.num_heads, self.head_dim_k, + dtype=self.dtype, device=torch.cuda.current_device()) + v_cache = torch.empty( + self.total_num_pages, self.page_size, self.num_heads, self.head_dim_v, + dtype=self.dtype, device=torch.cuda.current_device()) + self.cache[layer_number] = (k_cache, v_cache) + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + if self.is_cuda_graph: + self.page_table = torch.zeros( + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device='cuda') + + def print_cache(self): + """Print KV cache status""" + used_pages = [self.get_page_count(seq) for seq in self.sequences] + logger = logging.getLogger("PagedAttention") + logger.debug("cache status:") + logger.debug(f" total pages: {self.total_num_pages} (used {sum(used_pages)}, 
free {len(self.free_pages)})") + logger.debug(f" total sequences: {self.get_sequence_count()}") + for i, seq in enumerate(self.sequences): + logger.debug(f" >> batch index {i}: seq_id {seq}, num_tokens {self.get_sequence_lengths()[i]}, num_pages {self.get_page_count(seq)}, page_list {self.get_page_list(seq)}") + + def get_sequence_count(self): + """Get the total number of sequences in the KV cache""" + return len(self.sequences) + + def get_sequence_lengths(self): + """Get the list of sequence lengths in the KV cache""" + return list(self.sequences.values()) + + def has_free_page(self) -> bool: + """Whether the page pool has any free pages left""" + return len(self.free_pages) > 0 + + def get_page_count(self, seq: int): + """Get the number of pages allocated to a sequence""" + return len(self.allocated_pages[seq]) + + def get_page_list(self, seq: int): + """Get the list of pages allocated to a sequence""" + return [x.page_id for x in self.allocated_pages[seq]] + + def get_page_token_offsets(self, seqlen: int): + """Get the relevant page index and token index for a given sequence length""" + page_offset = seqlen // self.page_size + token_offset = seqlen % self.page_size + return (page_offset, token_offset) + + def get_page_table(self, sequences: List[int]): + """Get the page table, in shape [batch_size, max_pages_per_seq]""" + page_table = torch.Tensor([self.get_page_list(seq) + \ + [0]*(self.max_pages_per_seq-self.get_page_count(seq)) \ + for seq in sequences]).to(dtype=torch.int32, device='cpu') + if self.is_cuda_graph: + self.page_table[:self.get_sequence_count()].copy_(page_table) + else: + self.page_table = page_table.to(device='cuda') + return self.page_table + + def allocate_page(self, seq: int): + """Allocate a new page to a sequence""" + if not self.has_free_page(): + raise RuntimeError("KV cache is full!") + page = self.free_pages.pop(0) + page.allocate_page() + self.allocated_pages[seq].append(page) + + def allocate_sequence(self, seq: int, context_len: 
int): + """Add a new sequence to the cache""" + num_pages = context_len // self.page_size + if context_len % self.page_size > 0: + num_pages = num_pages + 1 + for _ in range(num_pages): + self.allocate_page(seq) + + def deallocate_sequence(self, seq: int): + """Deallocate all the pages for a sequence""" + for page in self.allocated_pages[seq]: + page.deallocate_page() + if not page.allocated: + self.free_pages.append(page) + self.allocated_pages.pop(seq) + + def step(self, layer_number: int, k: torch.Tensor, v: torch.Tensor, step_dict: OrderedDict, qkv_format: str): + """ + Update the paged KV cache for a given inference iteration. + For more details, please refer to InferenceParams.update_cache(). + + Parameters + ---------- + layer_number: int + The layer number of kv cache to operate on + k: torch.Tensor + A batch of new key tokens for the current iteration + v: torch.Tensor + A batch of new value tokens for the current iteration + step_dict: OrderedDict + The {seq_id: step_len} information for the new inference step + qkv_format: str + The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + The key cache tensor containing previous and the current tokens + v_cache: torch.Tensor + The value cache tensor containing previous and the current tokens + """ + batch_size = len(step_dict) + step_lens = list(step_dict.values()) + cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1,batch_size+1)] + + # Remove finished sequences and advance unfinished sequences + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + for seq in finished_seqs: + self.sequences.pop(seq) + self.deallocate_sequence(seq) + for seq in unfinished_seqs: + if (self.sequences[seq] % self.page_size == 0 + and self.sequences[seq] < self.max_seqlen): + self.allocate_page(seq) + self.sequences[seq] += 1 + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + 
for seq in new_seqs: + self.sequences[seq] = step_dict[seq] + self.allocate_sequence(seq, step_dict[seq]) + + # Copy new key and value tenosrs to the cache + seqlens = list(self.sequences.values()) + packed_k = torch.Tensor([]).to(dtype=k.dtype, device=k.device) + packed_v = torch.Tensor([]).to(dtype=v.dtype, device=v.device) + for i in range(batch_size): + if qkv_format == 'bshd': + packed_k = torch.cat([packed_k, k[i, :step_lens[i], :, :]], dim=0) + packed_v = torch.cat([packed_v, v[i, :step_lens[i], :, :]], dim=0) + if qkv_format == 'sbhd': + packed_k = torch.cat([packed_k, k[:step_lens[i], i, :, :]], dim=0) + packed_v = torch.cat([packed_v, v[:step_lens[i], i, :, :]], dim=0) + if qkv_format == 'thd': + packed_k = k + packed_v = v + k_cache, v_cache = self.cache[layer_number] + for i,seq in enumerate(step_dict.keys()): + page_list = self.get_page_list(seq) + start_page, start_token = self.get_page_token_offsets( + seqlens[i]-step_lens[i]) + end_page, end_token = self.get_page_token_offsets( + seqlens[i]) + if start_page == end_page: + page_id = page_list[start_page] + k_cache[page_id,start_token:end_token,:,:] = \ + packed_k[cu_seqlens[i]:cu_seqlens[i+1],:,:] + v_cache[page_id,start_token:end_token,:,:] = \ + packed_v[cu_seqlens[i]:cu_seqlens[i+1],:,:] + else: + start_offset = 0 + end_offset = 0 + for j in range(start_page, end_page+1): + if not (j == end_page and end_token == 0): + start_token_j = start_token if j == start_page else 0 + end_token_j = end_token if j == end_page else self.page_size + page_id = page_list[start_page] + end_offset = end_token_j - start_token_j + k_cache[page_id,start_token_j:end_token_j,:,:] = \ + packed_k[cu_seqlens[i]+start_offset:cu_seqlens[i]+end_offset,:,:] + v_cache[page_id,start_token_j:end_token_j,:,:] = \ + packed_v[cu_seqlens[i]+start_offset:cu_seqlens[i]+end_offset,:,:] + start_offset = start_offset + end_offset + + # Get page table + page_table = self.get_page_table(list(self.sequences.keys())) + + return k_cache, 
v_cache, page_table From 06605e56bcf830c2d6b0f8b57ac6570bb4bb028f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:04:10 -0800 Subject: [PATCH 017/239] remove unnecessary change from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index fa5b74cb45..c280049ee8 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1304,7 +1304,7 @@ def _run_transformer_layer( model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 12, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), From 0b2eb887ed175de6d96a53bd565e7a99534d4ea2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 20:51:49 -0800 Subject: [PATCH 018/239] test_fused_attn pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 49 +++++++++---------- .../fused_attn_f16_arbitrary_seqlen.cu | 6 ++- transformer_engine/pytorch/attention.py | 12 +++++ 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index c280049ee8..be4d359713 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1984,21 +1984,18 @@ def forward( qkv[:, :, 2, :, :] if 
cudnn_frontend_version == 1 else qkv[:, 2, :, :], fp8_dtype_forward, FusedAttnBackend["FP8"], - None, - None, - None, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + d_scale_qkv=fp8_meta["scaling_fwd"].scale_inv, + d_scale_qkv_offset=META_QKV, + d_scale_s=fp8_meta["scaling_fwd"].scale_inv, + d_scale_s_offset=META_S, + q_scale_s=fp8_meta["scaling_fwd"].scale, + q_scale_s_offset=META_S, + q_scale_o=fp8_meta["scaling_fwd"].scale, + q_scale_o_offset=META_O, + amax_s=fp8_meta["scaling_fwd"].amax_history, + amax_s_offset=META_S, + amax_o=fp8_meta["scaling_fwd"].amax_history, + amax_o_offset=META_O, attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, @@ -2070,18 +2067,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], - None, - None, - fwd_scale_inverses[META_QKV], # d_scale_qkv, - fwd_scale_inverses[META_S], # d_scale_s, - fwd_scale_inverses[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + d_scale_qkv=fwd_scale_inverses[META_QKV], + d_scale_s=fwd_scale_inverses[META_S], + d_scale_o=fwd_scale_inverses[META_O], + 
d_scale_do=ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], + d_scale_dp=ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], + q_scale_s=fwd_scales[META_S], + q_scale_dp=ctx.fp8_meta["scaling_bwd"].scale[META_DP], + q_scale_dqkv=ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], + amax_dp=ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], + amax_dqkv=ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 5ab2452b49..61370be44a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -472,7 +472,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; @@ -634,7 +635,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( k->set_ragged_offset(offset_k); v->set_ragged_offset(offset_v); o->set_ragged_offset(offset_o); + dO->set_ragged_offset(offset_o); } + stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("stats") .set_dim({b, h, s_q, 1}) @@ -667,6 +670,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); + sdpa_backward_options.set_max_total_seq_len_q(s_kv); } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { 
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 13cd7ac637..61abecda70 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -483,6 +483,11 @@ def get_attention_backend( # FusedAttention | non-paged/paged | FP16/BF16 # UnfusedDotProductAttention | non-paged/paged | FP32/FP16/BF16 if inference_params is not None: + if context_parallel: + logger.debug("Disabling all backends as KV caching is not supported for context parallelism") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: logger.debug("Disabling all backends as FP8 KV caching is not yet implemented") use_flash_attention = False @@ -494,6 +499,11 @@ def get_attention_backend( "Disabling FusedAttention as paged KV caching requires cuDNN 9.5+" ) use_fused_attention = False + if use_flash_attention and not _use_flash_attn_3 and not _flash_attn_2_5_7_plus: + logger.debug( + "Disabling FlashAttention as paged KV caching requires flash-attn 2.5.7+ or v3" + ) + use_flash_attention = False # Filter: Head dimension if use_flash_attention and head_dim_qk != head_dim_v: @@ -6900,6 +6910,8 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, + None, + None, fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv META_QKV, # d_scale_qkv_offset fp8_meta["scaling_fwd"].scale_inv, # d_scale_s From b0a5da4726024b00e6b717492d3673733757cd4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 05:04:48 +0000 Subject: [PATCH 019/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_paged_attn.py | 280 ++++++++++-------- tests/pytorch/test_numerics.py | 43 +-- .../common/fused_attn/fused_attn.cpp | 10 +- .../fused_attn_f16_arbitrary_seqlen.cu | 140 +++++---- 
.../fused_attn_f16_arbitrary_seqlen.h | 16 +- .../common/fused_attn/fused_attn_fp8.cu | 14 +- transformer_engine/common/fused_attn/utils.h | 17 +- .../include/transformer_engine/fused_attn.h | 45 +-- transformer_engine/pytorch/attention.py | 262 +++++++++------- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/attention.cu | 43 ++- .../pytorch/kv_cache_manager_non_paged.py | 65 ++-- .../pytorch/kv_cache_manager_paged.py | 128 +++++--- 13 files changed, 624 insertions(+), 442 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index a2fa6e81a5..9b0cc54d92 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -25,45 +25,46 @@ _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() + class Batch(object): def __init__(self): self.batch_size = 0 - self.seq_ids = torch.Tensor([]).to(dtype=torch.bool,device='cpu') - self.ctx_lens = torch.Tensor([]).to(dtype=torch.bool,device='cpu') - self.gen_lens = torch.Tensor([]).to(dtype=torch.bool,device='cpu') + self.seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") self.total_lens = self.ctx_lens + self.gen_lens - self.expected_gen_lens = torch.Tensor([]).to(dtype=torch.bool,device='cpu') - self.finished = torch.Tensor([]).to(dtype=torch.bool,device='cpu') - self.step_lens_q = torch.Tensor([]).to(dtype=torch.int32,device='cpu') + self.expected_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.finished = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.step_lens_q = torch.Tensor([]).to(dtype=torch.int32, device="cpu") def copy(self): new_batch = Batch() - new_batch.batch_size = self.batch_size - new_batch.seq_ids = self.seq_ids - new_batch.ctx_lens = self.ctx_lens - 
new_batch.gen_lens = self.gen_lens - new_batch.total_lens = self.total_lens - new_batch.expected_gen_lens = self.expected_gen_lens - new_batch.finished = self.finished - new_batch.step_lens_q = self.step_lens_q + new_batch.batch_size = self.batch_size + new_batch.seq_ids = self.seq_ids + new_batch.ctx_lens = self.ctx_lens + new_batch.gen_lens = self.gen_lens + new_batch.total_lens = self.total_lens + new_batch.expected_gen_lens = self.expected_gen_lens + new_batch.finished = self.finished + new_batch.step_lens_q = self.step_lens_q return new_batch - def print(self, logger, header='current batch:'): + def print(self, logger, header="current batch:"): logger.debug(header) - logger.debug(' {:<17s}: {}'.format('batch_size',self.batch_size)) - logger.debug(' {:<17s}: {}'.format('seq_ids',self.seq_ids.tolist())) - logger.debug(' {:<17s}: {}'.format('ctx_lens',self.ctx_lens.tolist())) - logger.debug(' {:<17s}: {}'.format('gen_lens',self.gen_lens.tolist())) - logger.debug(' {:<17s}: {}'.format('total_lens',self.total_lens.tolist())) - logger.debug(' {:<17s}: {}'.format('expected_gen_lens',self.expected_gen_lens.tolist())) - logger.debug(' {:<17s}: {}'.format('finished',self.finished.tolist())) - logger.debug(' {:<17s}: {}'.format('step_lens_q',self.step_lens_q.tolist())) + logger.debug(" {:<17s}: {}".format("batch_size", self.batch_size)) + logger.debug(" {:<17s}: {}".format("seq_ids", self.seq_ids.tolist())) + logger.debug(" {:<17s}: {}".format("ctx_lens", self.ctx_lens.tolist())) + logger.debug(" {:<17s}: {}".format("gen_lens", self.gen_lens.tolist())) + logger.debug(" {:<17s}: {}".format("total_lens", self.total_lens.tolist())) + logger.debug(" {:<17s}: {}".format("expected_gen_lens", self.expected_gen_lens.tolist())) + logger.debug(" {:<17s}: {}".format("finished", self.finished.tolist())) + logger.debug(" {:<17s}: {}".format("step_lens_q", self.step_lens_q.tolist())) def add_new_seqs(self, seq_ids, context_lens, expected_gen_lens): ctx_lens = context_lens[seq_ids] - 
gen_lens = torch.Tensor([0] * len(seq_ids)).to(dtype=torch.int32,device='cpu') + gen_lens = torch.Tensor([0] * len(seq_ids)).to(dtype=torch.int32, device="cpu") exp_gen_lens = expected_gen_lens[seq_ids] - finished = torch.Tensor([False] * len(seq_ids)).to(dtype=torch.bool,device='cpu') + finished = torch.Tensor([False] * len(seq_ids)).to(dtype=torch.bool, device="cpu") self.batch_size = self.batch_size + len(seq_ids) self.finished = torch.cat([self.finished, finished], dim=0) @@ -74,7 +75,7 @@ def add_new_seqs(self, seq_ids, context_lens, expected_gen_lens): self.gen_lens = gen_lens self.expected_gen_lens = exp_gen_lens else: - self.seq_ids = torch.cat([self.seq_ids, seq_ids],dim=0) + self.seq_ids = torch.cat([self.seq_ids, seq_ids], dim=0) self.ctx_lens = torch.cat([self.ctx_lens, ctx_lens], dim=0) self.gen_lens = torch.cat([self.gen_lens, gen_lens], dim=0) self.expected_gen_lens = torch.cat([self.expected_gen_lens, exp_gen_lens], dim=0) @@ -82,9 +83,9 @@ def add_new_seqs(self, seq_ids, context_lens, expected_gen_lens): self.step_lens_q = torch.cat([self.step_lens_q, ctx_lens], dim=0) def remove_finished(self): - self.finished = torch.where( - self.gen_lens - self.expected_gen_lens < 0, False, True).to( - dtype=torch.bool,device='cpu') + self.finished = torch.where(self.gen_lens - self.expected_gen_lens < 0, False, True).to( + dtype=torch.bool, device="cpu" + ) self.batch_size = self.finished.logical_not().sum().item() self.seq_ids = self.seq_ids[~self.finished] self.ctx_lens = self.ctx_lens[~self.finished] @@ -93,7 +94,8 @@ def remove_finished(self): self.expected_gen_lens = self.expected_gen_lens[~self.finished] self.gen_lens = self.gen_lens + 1 self.total_lens = self.total_lens + 1 - self.step_lens_q = torch.ones([self.batch_size], dtype=torch.int32, device='cpu') + self.step_lens_q = torch.ones([self.batch_size], dtype=torch.int32, device="cpu") + param_types = [torch.float16] if is_bf16_compatible(): @@ -102,13 +104,15 @@ def remove_finished(self): 
model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias "infer_0": ModelConfig(4, 16, 16, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=8), - "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), - } + "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), +} + +qkv_formats = ["bshd", "sbhd", "thd"] -qkv_formats = ['bshd', 'sbhd', 'thd'] def to_pretty_string(x: torch.Tensor): - return '['+','.join(['{:>3s}'.format(str(i)) for i in x.tolist()])+']' + return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]" + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @@ -118,16 +122,16 @@ def to_pretty_string(x: torch.Tensor): @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() - logger = logging.getLogger('test_paged_attn') + logger = logging.getLogger("test_paged_attn") config = model_configs_infer[model] layer_number = 1 - inference_params_qkv_format = 'bshd' + inference_params_qkv_format = "bshd" if is_paged: - qkv_layout = "paged_kv_"+inference_params_qkv_format+'_2'+inference_params_qkv_format + qkv_layout = "paged_kv_" + inference_params_qkv_format + "_2" + inference_params_qkv_format else: - qkv_layout = '_'.join([inference_params_qkv_format]*3) + qkv_layout = "_".join([inference_params_qkv_format] * 3) available_backends, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, @@ -161,9 +165,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): max_seqlen_kv_roundup = max_seqlen_kv if is_paged: - max_seqlen_kv_roundup = int((max_seqlen_kv + page_size - 1)//page_size * page_size) + max_seqlen_kv_roundup = int((max_seqlen_kv + page_size - 1) // page_size * page_size) else: - max_seqlen_kv_roundup = int((max_seqlen_kv + 63)//64 * 64) + 
max_seqlen_kv_roundup = int((max_seqlen_kv + 63) // 64 * 64) cache_size = max_batch_size * max_seqlen_kv_roundup total_num_pages = int(cache_size / page_size) @@ -173,16 +177,20 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): max_gen_len = int(max_seqlen_kv * gen_ratio) # context lengths in Uniform distribution - context_lens = torch.randint(1, max_context_len, [total_requests], dtype=torch.int32, device='cpu') + context_lens = torch.randint( + 1, max_context_len, [total_requests], dtype=torch.int32, device="cpu" + ) # generation lengths in Exponential distribution - gen_dist = Exponential(1/max_gen_len) + gen_dist = Exponential(1 / max_gen_len) gen_lens = gen_dist.sample((total_requests,)) - gen_lens = torch.where(gen_lens>max_gen_len, max_gen_len, gen_lens).to(dtype=torch.int32, device='cpu') + gen_lens = torch.where(gen_lens > max_gen_len, max_gen_len, gen_lens).to( + dtype=torch.int32, device="cpu" + ) # arrival times in Poisson distribution rate = torch.randint(1, max_batch_size, [1]).item() interval_dist = Exponential(rate) arrival_intervals = interval_dist.sample((total_requests,)) - arrival_times = torch.cumsum(arrival_intervals,dim=0).to(dtype=torch.int32, device='cpu') + arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") last_arrival = arrival_times.max().item() logger.info("Simulation:") @@ -208,31 +216,37 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): layer_number=layer_number, attention_dropout=0.0, attn_mask_type="causal", - qkv_format='bshd', + qkv_format="bshd", ) .cuda() .eval() ) q = 0.1 * torch.randn( - (total_requests, max_seqlen_kv, config.num_heads, config.head_dim_qk), - dtype=dtype, device="cuda") + (total_requests, max_seqlen_kv, config.num_heads, config.head_dim_qk), + dtype=dtype, + device="cuda", + ) k = 0.1 * torch.randn( - (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk), - dtype=dtype, 
device="cuda") + (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk), + dtype=dtype, + device="cuda", + ) v = 0.1 * torch.randn( - (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_v), - dtype=dtype, device="cuda") + (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_v), + dtype=dtype, + device="cuda", + ) logger.info("") logger.info("=== Generating all tokens at once ===") - request_delays = torch.zeros([total_requests],dtype=torch.int32,device='cpu') + request_delays = torch.zeros([total_requests], dtype=torch.int32, device="cpu") full_output = model( - query_layer=q, - key_layer=k, - value_layer=v, - qkv_format='bshd', - attn_mask_type="causal", + query_layer=q, + key_layer=k, + value_layer=v, + qkv_format="bshd", + attn_mask_type="causal", ) t = 1 @@ -244,26 +258,26 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): logger.info("") logger.info("=== Generating one token at a time ===") inference_params = InferenceParams( - max_batch_size=max_batch_size, - max_seqlen_kv=max_seqlen_kv_roundup, - num_heads_kv=config.num_gqa_groups, - head_dim_k=config.head_dim_qk, - head_dim_v=config.head_dim_v, - dtype=dtype, - is_paged=is_paged, - page_size=page_size, - total_num_pages=total_num_pages, - is_cuda_graph=is_cuda_graph, - num_heads_q=config.num_heads, - head_dim_q=config.head_dim_qk, - ) + max_batch_size=max_batch_size, + max_seqlen_kv=max_seqlen_kv_roundup, + num_heads_kv=config.num_gqa_groups, + head_dim_k=config.head_dim_qk, + head_dim_v=config.head_dim_v, + dtype=dtype, + is_paged=is_paged, + page_size=page_size, + total_num_pages=total_num_pages, + is_cuda_graph=is_cuda_graph, + num_heads_q=config.num_heads, + head_dim_q=config.head_dim_qk, + ) inference_params.allocate_memory(layer_number) inference_params.print() - request_delays = torch.zeros([total_requests],dtype=torch.int32,device='cpu') + request_delays = torch.zeros([total_requests], dtype=torch.int32, 
device="cpu") t = 0 prev = Batch() - delayed_seq_ids = torch.Tensor().to(dtype=torch.int32,device='cpu') + delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu") while True: logger.debug(f"time step {t}") cur = prev.copy() @@ -277,7 +291,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): allowed_num_new_seqs = max_batch_size - cur.batch_size else: allowed_num_new_seqs = 0 if cur.batch_size > 0 else max_batch_size - queuing_seq_ids = torch.cat([delayed_seq_ids, arrived_seq_ids],dim=0) + queuing_seq_ids = torch.cat([delayed_seq_ids, arrived_seq_ids], dim=0) logger.debug(f"arrived seq_ids: {to_pretty_string(arrived_seq_ids)}") logger.debug(f"previously delayed seq_ids: {to_pretty_string(delayed_seq_ids)}") logger.debug(f"allowed num of new sequences: {allowed_num_new_seqs}") @@ -304,50 +318,77 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): continue if not is_cuda_graph: - max_seqlen_q_infer = int((cur.step_lens_q.max().item() + 63)//64 * 64) + max_seqlen_q_infer = int((cur.step_lens_q.max().item() + 63) // 64 * 64) else: max_seqlen_q_infer = max_seqlen_kv_roundup # create incremental input - if qkv_format == 'thd': - incremental_q = torch.Tensor().to(dtype=dtype, device='cuda') - incremental_k = torch.Tensor().to(dtype=dtype, device='cuda') - incremental_v = torch.Tensor().to(dtype=dtype, device='cuda') - for i,seq in enumerate(cur.seq_ids): - start = (cur.total_lens[i]-cur.step_lens_q[i]).item() + if qkv_format == "thd": + incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") + incremental_k = torch.Tensor().to(dtype=dtype, device="cuda") + incremental_v = torch.Tensor().to(dtype=dtype, device="cuda") + for i, seq in enumerate(cur.seq_ids): + start = (cur.total_lens[i] - cur.step_lens_q[i]).item() end = cur.total_lens[i].item() - incremental_q = torch.cat([incremental_q, - q[seq, start:end, :, :]],dim=0) - incremental_k = torch.cat([incremental_k, - k[seq, start:end, :, 
:].view(-1, config.num_gqa_groups, config.head_dim_qk)], dim=0) - incremental_v = torch.cat([incremental_v, - v[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_v)], dim=0) + incremental_q = torch.cat([incremental_q, q[seq, start:end, :, :]], dim=0) + incremental_k = torch.cat( + [ + incremental_k, + k[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_qk), + ], + dim=0, + ) + incremental_v = torch.cat( + [ + incremental_v, + v[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_v), + ], + dim=0, + ) else: incremental_q = torch.zeros( - cur.batch_size, max_seqlen_q_infer, config.num_heads, config.head_dim_qk, - dtype=dtype, device='cuda') + cur.batch_size, + max_seqlen_q_infer, + config.num_heads, + config.head_dim_qk, + dtype=dtype, + device="cuda", + ) incremental_k = torch.zeros( - cur.batch_size, max_seqlen_q_infer, config.num_gqa_groups, config.head_dim_qk, - dtype=dtype, device='cuda') + cur.batch_size, + max_seqlen_q_infer, + config.num_gqa_groups, + config.head_dim_qk, + dtype=dtype, + device="cuda", + ) incremental_v = torch.zeros( - cur.batch_size, max_seqlen_q_infer, config.num_gqa_groups, config.head_dim_v, - dtype=dtype, device='cuda') - for i,seq in enumerate(cur.seq_ids): - start = (cur.total_lens[i]-cur.step_lens_q[i]).item() + cur.batch_size, + max_seqlen_q_infer, + config.num_gqa_groups, + config.head_dim_v, + dtype=dtype, + device="cuda", + ) + for i, seq in enumerate(cur.seq_ids): + start = (cur.total_lens[i] - cur.step_lens_q[i]).item() end = cur.total_lens[i].item() - incremental_q[i, :cur.step_lens_q[i], :, :] = q[seq, start:end, :, :] - incremental_k[i, :cur.step_lens_q[i], :, :] = k[seq, start:end, :, :] - incremental_v[i, :cur.step_lens_q[i], :, :] = v[seq, start:end, :, :] - if qkv_format == 'sbhd': + incremental_q[i, : cur.step_lens_q[i], :, :] = q[seq, start:end, :, :] + incremental_k[i, : cur.step_lens_q[i], :, :] = k[seq, start:end, :, :] + incremental_v[i, : 
cur.step_lens_q[i], :, :] = v[seq, start:end, :, :] + if qkv_format == "sbhd": incremental_q, incremental_k, incremental_v = [ - x.transpose(0,1) for x in [incremental_q, incremental_k, incremental_v]] + x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] + ] cu_seqlens_q = torch.zeros(cur.batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1:cur.batch_size+1] = torch.cumsum(cur.step_lens_q, dim=0) + cu_seqlens_q[1 : cur.batch_size + 1] = torch.cumsum(cur.step_lens_q, dim=0) cu_seqlens_kv = torch.zeros(cur.batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv[1:cur.batch_size+1] = torch.cumsum(cur.total_lens, dim=0) + cu_seqlens_kv[1 : cur.batch_size + 1] = torch.cumsum(cur.total_lens, dim=0) - inference_params.step_dict = OrderedDict(zip(cur.seq_ids.tolist(), cur.step_lens_q.tolist())) + inference_params.step_dict = OrderedDict( + zip(cur.seq_ids.tolist(), cur.step_lens_q.tolist()) + ) line_output = model( query_layer=incremental_q, @@ -374,25 +415,28 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): torch.half: 4e-3, torch.bfloat16: 1e-2, } - for i,seq in enumerate(cur.seq_ids): - if qkv_format == 'bshd': + for i, seq in enumerate(cur.seq_ids): + if qkv_format == "bshd": torch.testing.assert_close( - full_output[seq,cur.total_lens[i]-1,:], - line_output[i,cur.step_lens_q[i]-1,:], - atol = tols[dtype], - rtol = tols[dtype]) - if qkv_format == 'sbhd': + full_output[seq, cur.total_lens[i] - 1, :], + line_output[i, cur.step_lens_q[i] - 1, :], + atol=tols[dtype], + rtol=tols[dtype], + ) + if qkv_format == "sbhd": torch.testing.assert_close( - full_output[seq,cur.total_lens[i]-1,:], - line_output[cur.step_lens_q[i]-1,i,:], - atol = tols[dtype], - rtol = tols[dtype]) - if qkv_format == 'thd': + full_output[seq, cur.total_lens[i] - 1, :], + line_output[cur.step_lens_q[i] - 1, i, :], + atol=tols[dtype], + rtol=tols[dtype], + ) + if qkv_format == "thd": torch.testing.assert_close( - 
full_output[seq,cur.total_lens[i]-1,:], - line_output[cu_seqlens_q[i+1]-1,:], - atol = tols[dtype], - rtol = tols[dtype]) + full_output[seq, cur.total_lens[i] - 1, :], + line_output[cu_seqlens_q[i + 1] - 1, :], + atol=tols[dtype], + rtol=tols[dtype], + ) prev = cur.copy() t += 1 diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index aef3fab070..1dc0b56d99 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1939,7 +1939,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("backend", backends_inference) @pytest.mark.parametrize("is_paged", [False, True]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged, is_cuda_graph): +def test_kv_cache_accuracy( + dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged, is_cuda_graph +): reset_rng_states() if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: @@ -1998,22 +2000,22 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ) inference_params = InferenceParams( - max_batch_size=B_max, - max_seqlen_kv=S_max, - num_heads_kv=H, - head_dim_k=head_size, - dtype=dtype, - is_paged=is_paged, - total_num_pages=4, - page_size=256, - is_cuda_graph=is_cuda_graph, - num_heads_q=H, - head_dim_q=head_size, - ) + max_batch_size=B_max, + max_seqlen_kv=S_max, + num_heads_kv=H, + head_dim_k=head_size, + dtype=dtype, + is_paged=is_paged, + total_num_pages=4, + page_size=256, + is_cuda_graph=is_cuda_graph, + num_heads_q=H, + head_dim_q=head_size, + ) rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") - inference_params.step_dict = OrderedDict(zip(list(range(B)), [1]*B)) + inference_params.step_dict = OrderedDict(zip(list(range(B)), [1] * B)) input = torch.randn((S, B, D), dtype=dtype, device="cuda") if input_format == "bshd": @@ 
-2034,28 +2036,29 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - seqlens_kv = (i+1) * torch.ones(B, dtype=torch.int32, device="cuda") + seqlens_kv = (i + 1) * torch.ones(B, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(B + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) mask_type = "padding" - kwargs={} + kwargs = {} if module == "TransformerLayer": - kwargs['self_attn_mask_type']=mask_type + kwargs["self_attn_mask_type"] = mask_type else: - kwargs['attn_mask_type']=mask_type + kwargs["attn_mask_type"] = mask_type line_output = model( hidden_states=incremental_input, inference_params=inference_params, rotary_pos_emb=rotary_freqs if use_RoPE else None, **kwargs, - max_seqlen_q=1, max_seqlen_kv=S, + max_seqlen_q=1, + max_seqlen_kv=S, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, ) if input_format == "sbhd": - incremental_output[i,:,:] = line_output.view(B, D) + incremental_output[i, :, :] = line_output.view(B, D) else: incremental_output[:, i, :] = line_output.view(B, D) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3e74d1c451..cbb7a7c74a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -631,7 +631,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, 
const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, @@ -713,11 +714,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index c10b60249a..5dd29fc116 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -50,14 +50,16 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, - bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, + int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, - void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, - size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, 
cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -78,7 +80,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); - bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || + layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); } @@ -103,7 +106,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_kv, d_qk, d_v, - num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, + max_pages_per_seq_k, + max_pages_per_seq_v, bias_b, bias_h, scaling_factor, @@ -130,7 +138,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // bias std::shared_ptr, // seq_q std::shared_ptr, // seq_kv - std::shared_ptr, // page_table_k + std::shared_ptr, // page_table_k std::shared_ptr, // page_table_v std::shared_ptr, // offset_q std::shared_ptr, // offset_k @@ -171,10 +179,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); if (is_paged_kv) { - generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), + layout, 
NVTE_QKV_Matrix::NVTE_V_Matrix); } else { generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); @@ -194,12 +202,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); Q->set_ragged_offset(offset_q); } - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_stride(k_stride)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_stride(v_stride)); + K = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("K").set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("V").set_stride(v_stride)); if (is_paged_kv) { K->set_dim({num_pages_k, hg, page_size_k, d_qk}); V->set_dim({num_pages_v, hg, page_size_v, d_v}); @@ -265,16 +269,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } if (is_paged_kv) { - page_table_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("page_table_k") - .set_dim({b, 1, max_pages_per_seq_k, 1}) - .set_stride({{max_pages_per_seq_k, max_pages_per_seq_v, 1, 1}}) - .set_data_type(fe::DataType_t::INT32)); - page_table_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("page_table_v") - .set_dim({b, 1, max_pages_per_seq_v, 1}) - .set_stride({{max_pages_per_seq_v, max_pages_per_seq_v, 1, 1}}) - .set_data_type(fe::DataType_t::INT32)); + page_table_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_k") + .set_dim({b, 1, max_pages_per_seq_k, 1}) + .set_stride({{max_pages_per_seq_k, max_pages_per_seq_v, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); + page_table_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_v") + .set_dim({b, 1, max_pages_per_seq_v, 1}) + .set_stride({{max_pages_per_seq_v, max_pages_per_seq_v, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); sdpa_options.set_paged_attention_k_table(page_table_k); sdpa_options.set_paged_attention_v_table(page_table_v); 
sdpa_options.set_paged_attention_max_seq_len_kv(static_cast(s_kv)); @@ -309,9 +315,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}); + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); if (is_ragged && cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -334,8 +338,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto page_table_tuple = - is_paged_kv ? std::make_tuple(page_table_k, page_table_v) : std::make_tuple(nullptr, nullptr); + auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) + : std::make_tuple(nullptr, nullptr); auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) : std::make_tuple(nullptr, nullptr, nullptr, nullptr); auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) @@ -350,16 +354,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, - padding_tuple, page_table_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, + bias_tuple, padding_tuple, page_table_tuple, + offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, offset_q, offset_k, - offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = + auto 
[mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, + offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -486,7 +490,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); - bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || + layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); } @@ -514,7 +519,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( s_kv, d_qk, d_v, - 0,0,0,0,0,0, + 0, + 0, + 0, + 0, + 0, + 0, bias_b, bias_h, scaling_factor, @@ -987,11 +997,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0,0,0,0,0,0, bias_b, bias_h, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, 
devPtrSeqOffsets, + devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, + handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1203,11 +1214,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, 0,0,0,0,0,0, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, - devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, nullptr, nullptr, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, nullptr, nullptr, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1319,13 +1331,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, - size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t 
window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1419,11 +1433,13 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, - devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - 
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index cf6b2664bb..a75730a421 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -61,13 +61,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, - size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, 
const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index b77e954f4a..c768c9a499 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1681,7 +1681,12 @@ void fused_attn_fp8_fwd_impl_v1( s_kv, d, d, - 0,0,0,0,0,0, + 0, + 0, + 0, + 0, + 0, + 0, bias_b, bias_h, scaling_factor, @@ -1986,7 +1991,12 @@ void fused_attn_fp8_bwd_impl_v1( s_kv, d, d, - 0,0,0,0,0,0, + 0, + 0, + 0, + 0, + 0, + 0, bias_b, bias_h, scaling_factor, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 4b39d2c182..0aeb8672ba 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -114,13 +114,16 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, 
hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, - dropoutProbability, layout, mask_type, window_size_left, window_size_right, - deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, - rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, - rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, + window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, + rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, + rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, + rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 12e96f6d0a..24722b10e5 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -28,27 +28,27 @@ extern "C" { * different lengths in a batch. `Paged_KV`-based layouts are used for paged attention. 
*/ enum NVTE_QKV_Layout { - NVTE_SB3HD = 0, /*!< SB3HD layout */ - NVTE_SBH3D = 1, /*!< SBH3D layout */ - NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ - NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ - NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ - NVTE_BS3HD = 5, /*!< BS3HD layout */ - NVTE_BSH3D = 6, /*!< BSH3D layout */ - NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ - NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ - NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ - NVTE_T3HD = 10, /*!< T3HD layout */ - NVTE_TH3D = 11, /*!< TH3D layout */ - NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ - NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ - NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ - NVTE_Paged_KV_BSHD_2BSHD = 15, /*!< Paged_KV_BSHD_2BSHD layout */ - NVTE_Paged_KV_BSHD_2SBHD = 16, /*!< Paged_KV_BSHD_2SBHD layout */ - NVTE_Paged_KV_SBHD_2BSHD = 17, /*!< Paged_KV_SBHD_2BSHD layout */ - NVTE_Paged_KV_SBHD_2SBHD = 18, /*!< Paged_KV_SBHD_2SBHD layout */ - NVTE_Paged_KV_THD_2BSHD = 19, /*!< Paged_KV_THD_2BSHD layout */ - NVTE_Paged_KV_THD_2SBHD = 20, /*!< Paged_KV_THD_2SBHD layout */ + NVTE_SB3HD = 0, /*!< SB3HD layout */ + NVTE_SBH3D = 1, /*!< SBH3D layout */ + NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ + NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ + NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ + NVTE_BS3HD = 5, /*!< BS3HD layout */ + NVTE_BSH3D = 6, /*!< BSH3D layout */ + NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ + NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ + NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ + NVTE_T3HD = 10, /*!< T3HD layout */ + NVTE_TH3D = 11, /*!< TH3D layout */ + NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ + NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ + NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_Paged_KV_BSHD_2BSHD = 15, /*!< Paged_KV_BSHD_2BSHD layout */ + NVTE_Paged_KV_BSHD_2SBHD = 16, /*!< Paged_KV_BSHD_2SBHD layout */ + NVTE_Paged_KV_SBHD_2BSHD = 17, /*!< Paged_KV_SBHD_2BSHD 
layout */ + NVTE_Paged_KV_SBHD_2SBHD = 18, /*!< Paged_KV_SBHD_2SBHD layout */ + NVTE_Paged_KV_THD_2BSHD = 19, /*!< Paged_KV_THD_2BSHD layout */ + NVTE_Paged_KV_THD_2SBHD = 20, /*!< Paged_KV_THD_2SBHD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -477,7 +477,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d37a7ab049..c116f567c6 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -484,7 +484,9 @@ def get_attention_backend( # UnfusedDotProductAttention | non-paged/paged | FP32/FP16/BF16 if inference_params is not None: if context_parallel: - logger.debug("Disabling all backends as KV caching is not supported for context parallelism") + logger.debug( + "Disabling all backends as KV caching is not supported for context parallelism" + ) use_flash_attention = False use_fused_attention = False use_unfused_attention = False @@ -495,9 +497,7 @@ def get_attention_backend( use_unfused_attention = False if inference_params.is_paged: if use_fused_attention and cudnn_version < (9, 5, 0): - logger.debug( - "Disabling FusedAttention as paged KV caching requires cuDNN 9.5+" - ) + logger.debug("Disabling FusedAttention as paged KV caching requires cuDNN 9.5+") use_fused_attention 
= False if use_flash_attention and not _use_flash_attn_3 and not _flash_attn_2_5_7_plus: logger.debug( @@ -535,7 +535,9 @@ def get_attention_backend( use_fused_attention = False # Filter: QKV layout - qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) + qkv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) if qkv_format == "thd": if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") @@ -1041,27 +1043,29 @@ class InferenceParams: # pylint: disable=too-few-public-methods page size in number of tokens if is_paged = True. """ - def __init__(self, - max_batch_size: int, - max_seqlen_kv: int, - num_heads_kv: int, - head_dim_k: int, - dtype: torch.dtype, - head_dim_v: Optional[int] = None, - is_paged: bool = False, - total_num_pages: Optional[int] = None, - page_size: Optional[int] = None, - is_cuda_graph: bool = False, - num_heads_q: Optional[int] = None, - head_dim_q: Optional[int] = None, - ): + def __init__( + self, + max_batch_size: int, + max_seqlen_kv: int, + num_heads_kv: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + is_paged: bool = False, + total_num_pages: Optional[int] = None, + page_size: Optional[int] = None, + is_cuda_graph: bool = False, + num_heads_q: Optional[int] = None, + head_dim_q: Optional[int] = None, + ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv self.num_heads_kv = num_heads_kv self.head_dim_k = head_dim_k - assert ( - dtype in [torch.float32, torch.float16, torch.bfloat16] - ), "Supported InferenceParams.dtype = {torch.float32, torch.float16, torch.bfloat16}. Found {dtype}." + assert dtype in [torch.float32, torch.float16, torch.bfloat16], ( + "Supported InferenceParams.dtype = {torch.float32, torch.float16, torch.bfloat16}." + " Found {dtype}." 
+ ) self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k self.is_paged = is_paged @@ -1077,12 +1081,18 @@ def __init__(self, dtype=self.dtype, head_dim_v=self.head_dim_v, is_cuda_graph=self.is_cuda_graph, - ) + ) else: assert page_size is not None, "page_size is required when is_paged=True!" assert total_num_pages is not None, "total_num_pages is required when is_paged=True!" self.page_size = page_size - self.max_seqlen_kv = self.max_seqlen_kv if self.max_seqlen_kv >= self.page_size else int((self.max_seqlen_kv + self.page_size -1)//self.page_size * self.page_size) + self.max_seqlen_kv = ( + self.max_seqlen_kv + if self.max_seqlen_kv >= self.page_size + else int( + (self.max_seqlen_kv + self.page_size - 1) // self.page_size * self.page_size + ) + ) self.total_num_pages = total_num_pages self.cache_manager = PagedKVCacheManager( total_num_pages=self.total_num_pages, @@ -1094,7 +1104,7 @@ def __init__(self, max_seqlen=self.max_seqlen_kv, head_dim_v=self.head_dim_v, is_cuda_graph=self.is_cuda_graph, - ) + ) if self.is_cuda_graph: assert num_heads_q is not None, "num_heads_q is required when is_cuda_graph=True!" @@ -1103,7 +1113,7 @@ def __init__(self, self.head_dim_q = head_dim_q # memory format for the cache; at the moment, only 'bshd' is supported - self.qkv_format = 'bshd' + self.qkv_format = "bshd" # layer numbers that we have kv cache for self.layer_numbers = [] # sequence ids that are stored in the cache @@ -1137,7 +1147,6 @@ def print(self): logger.debug(f" head_dim: k: {self.head_dim_k}, v: {self.head_dim_v}") logger.debug(f" layer_numbers: {self.layer_numbers}") - def allocate_memory(self, layer_number): """ Allocate memory for the KV cache for the layer #layer_number. 
@@ -1177,12 +1186,13 @@ def allocate_memory(self, layer_number): device=torch.cuda.current_device(), ) - def reshape_and_copy_q(self, - q: torch.Tensor, - source_qkv_format: str, - target_qkv_format: str, - layer_number: Optional[int] = None, - ): + def reshape_and_copy_q( + self, + q: torch.Tensor, + source_qkv_format: str, + target_qkv_format: str, + layer_number: Optional[int] = None, + ): """ Convert the new query tokens from 'source_qkv_format' to 'target_qkv_format', so that it is consistent with the KV cache format. At the moment, only 'bshd' format @@ -1191,22 +1201,29 @@ def reshape_and_copy_q(self, """ actual_batch_size = len(self.step_dict) seqlens_q = list(self.step_dict.values()) - cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size+1)] - batch_wide_max_seqlen_q = int((max(seqlens_q) + 63)//64 * 64) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + batch_wide_max_seqlen_q = int((max(seqlens_q) + 63) // 64 * 64) if not self.is_cuda_graph: - if source_qkv_format == 'bshd': + if source_qkv_format == "bshd": q = q.contiguous() - if source_qkv_format == 'sbhd': - q = q.transpose(0,1).contiguous() - if source_qkv_format == 'thd': + if source_qkv_format == "sbhd": + q = q.transpose(0, 1).contiguous() + if source_qkv_format == "thd": padded_q = torch.zeros( - actual_batch_size, batch_wide_max_seqlen_q, q.shape[-2], q.shape[-1], - dtype=q.dtype, device='cuda') + actual_batch_size, + batch_wide_max_seqlen_q, + q.shape[-2], + q.shape[-1], + dtype=q.dtype, + device="cuda", + ) for i in range(actual_batch_size): - padded_q[i, :seqlens_q[i], :, :] = q[cu_seqlens_q[i]:cu_seqlens_q[i+1], :, :] + padded_q[i, : seqlens_q[i], :, :] = q[ + cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : + ] q = padded_q - if source_qkv_format in ['bshd', 'sbhd']: + if source_qkv_format in ["bshd", "sbhd"]: self.max_seqlen_q = q.shape[1] else: self.max_seqlen_q = batch_wide_max_seqlen_q @@ -1219,16 +1236,22 @@ def 
reshape_and_copy_q(self, ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" q_buffer = self.q_buffer[layer_number] for i in range(actual_batch_size): - if source_qkv_format == 'bshd': - q_buffer[i, :seqlens_q[i], :, :] = q[i, :seqlens_q[i], :, :] - if source_qkv_format == 'sbhd': - q_buffer[i, :seqlens_q[i], :, :] = q[:seqlens_q[i], i, :, :] - if source_qkv_format == 'thd': - q_buffer[i, :seqlens_q[i], :, :] = q[cu_seqlens_q[i]:cu_seqlens_q[i+1], :, :] - q_buffer[i, seqlens_q[i]:, :, :].fill_(0) + if source_qkv_format == "bshd": + q_buffer[i, : seqlens_q[i], :, :] = q[i, : seqlens_q[i], :, :] + if source_qkv_format == "sbhd": + q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] + if source_qkv_format == "thd": + q_buffer[i, : seqlens_q[i], :, :] = q[ + cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : + ] + q_buffer[i, seqlens_q[i] :, :, :].fill_(0) - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]]*(self.max_batch_size - actual_batch_size) - self.cu_seqlens_q_buffer.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device='cpu')) + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * ( + self.max_batch_size - actual_batch_size + ) + self.cu_seqlens_q_buffer.copy_( + torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") + ) # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] return q_buffer @@ -1266,37 +1289,42 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): k_cache[page_table.flatten()], "(b npages) page_size ... -> b (npages page_size) ...", b=batch_size, - ) + ) new_v_cache = rearrange( v_cache[page_table.flatten()], "(b npages) page_size ... 
-> b (npages page_size) ...", b=batch_size, - ) + ) for i in range(actual_batch_size): - new_k_cache[i, self.seqlens[i]:,:,:].fill_(0) - new_v_cache[i, self.seqlens[i]:,:,:].fill_(0) - if qkv_format == 'bshd': + new_k_cache[i, self.seqlens[i] :, :, :].fill_(0) + new_v_cache[i, self.seqlens[i] :, :, :].fill_(0) + if qkv_format == "bshd": new_k_cache = new_k_cache.contiguous() new_v_cache = new_v_cache.contiguous() - if qkv_format == 'sbhd': - new_k_cache = new_k_cache.transpose(0,1).contiguous() - new_v_cache = new_v_cache.transpose(0,1).contiguous() - if qkv_format == 'thd': - packed_k_cache = torch.Tensor().to(dtype=k_cache.dtype,device=k_cache.device) - packed_v_cache = torch.Tensor().to(dtype=v_cache.dtype,device=v_cache.device) + if qkv_format == "sbhd": + new_k_cache = new_k_cache.transpose(0, 1).contiguous() + new_v_cache = new_v_cache.transpose(0, 1).contiguous() + if qkv_format == "thd": + packed_k_cache = torch.Tensor().to(dtype=k_cache.dtype, device=k_cache.device) + packed_v_cache = torch.Tensor().to(dtype=v_cache.dtype, device=v_cache.device) for i in range(batch_size): - packed_k_cache = torch.cat([packed_k_cache, new_k_cache[i,:self.seqlens[i],:,:]], dim=0) - packed_v_cache = torch.cat([packed_v_cache, new_v_cache[i,:self.seqlens[i],:,:]], dim=0) + packed_k_cache = torch.cat( + [packed_k_cache, new_k_cache[i, : self.seqlens[i], :, :]], dim=0 + ) + packed_v_cache = torch.cat( + [packed_v_cache, new_v_cache[i, : self.seqlens[i], :, :]], dim=0 + ) new_k_cache = packed_k_cache.contiguous() new_v_cache = packed_v_cache.contiguous() return new_k_cache, new_v_cache - def update_cache(self, - layer_number: int, - k: torch.Tensor, - v: torch.Tensor, - qkv_format: str, - ): + def update_cache( + self, + layer_number: int, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str, + ): """ Update KV cache with the new key/value tokens for a given inference iteration. 
@@ -1364,9 +1392,13 @@ def update_cache(self, if self.is_cuda_graph: actual_batch_size = len(self.seqlens) - cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size+1)] - cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]]*(self.max_batch_size - actual_batch_size) - self.cu_seqlens_kv_buffer.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device='cpu')) + cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size + 1)] + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + self.max_batch_size - actual_batch_size + ) + self.cu_seqlens_kv_buffer.copy_( + torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") + ) # k_cache and v_cache are in InferenceParams.qkv_format format return k_cache, v_cache, page_table @@ -5100,9 +5132,13 @@ def forward( qkv_layout in QKVLayouts ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" - qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) + qkv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) if inference_params is not None and inference_params.is_paged: - key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number, qkv_format) + key_layer, value_layer = inference_params.convert_paged_to_nonpaged( + self.layer_number, qkv_format + ) if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now @@ -5667,7 +5703,9 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) + qkv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): if qkv_format == "sbhd": @@ -5713,7 +5751,9 @@ def forward( if "padding" in attn_mask_type: assert not 
context_parallel, "Padding mask not supported with context parallelism!" - if inference_params is None or (inference_params is not None and not inference_params.is_paged): + if inference_params is None or ( + inference_params is not None and not inference_params.is_paged + ): # [b * s, h, d] query_layer, key_layer, value_layer = [ x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) @@ -5741,20 +5781,28 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) - cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) + cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices( + attention_mask[1] + ) else: indices_q = get_indices(max_seqlen_q, cu_seqlens_q) indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) query_layer = PackTensors.apply(indices_q, query_layer) - key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) + key_layer, value_layer = PackTensors.apply( + indices_kv, key_layer, value_layer + ) else: # [b * s, h, d] - query_layer = query_layer.reshape(query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:]) + query_layer = query_layer.reshape( + query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:] + ) if cu_seqlens_q is None: assert ( attention_mask is not None ), "Please provide attention_mask for padding!" 
- cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask if self.attention_type == "self" else attention_mask[0]) + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices( + attention_mask if self.attention_type == "self" else attention_mask[0] + ) else: indices_q = get_indices(max_seqlen_q, cu_seqlens_q) query_layer = PackTensors.apply(indices_q, query_layer) @@ -6023,7 +6071,7 @@ def forward( fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) assert ( qkv_group == 1 ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." @@ -6418,7 +6466,7 @@ def forward( q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) assert qkv_group == 2, ( "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " f"but found {qkv_layout}." 
@@ -6876,7 +6924,7 @@ def forward( q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) if qkv_group == 1: dim = qkv_layout.find("3") qkv = _combine_tensors([q, k, v], dim) @@ -6890,7 +6938,7 @@ def forward( q_fp8 = cast_to_fp8( q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward ).view(q.shape) - dim = qkv_layout.replace("paged_kv_","").split("_")[1].find("2") + dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) kv_fp8 = cast_to_fp8( @@ -6967,7 +7015,7 @@ def forward( if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate if is_input_fp8: - qkv_group = len(qkv_layout.replace("paged_kv_","").split("_")) + qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) if qkv_group == 1: dim = qkv_layout.find("3") qkv = _combine_tensors([q, k, v], dim) @@ -6989,7 +7037,7 @@ def forward( fp8_dtype_forward, TE_DType[q.dtype], ).view(q.shape) - dim = qkv_layout.replace("paged_kv_","").split("_")[1].find("2") + dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) kv_no_fp8 = cast_from_fp8( @@ -7269,7 +7317,7 @@ def backward(ctx, d_out): dtype=d_out_f8tensor.dtype, ) else: - qkv_group = len(ctx.qkv_layout.replace("paged_kv_","").split("_")) + qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) if qkv_group == 1: dim = ctx.qkv_layout.find("3") dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim) @@ -7293,7 +7341,7 @@ def backward(ctx, d_out): fp8_dtype_backward, ctx.qkv_dtype, ).view(dq_fp8.shape) - dim = ctx.qkv_layout.replace("paged_kv_","").split("_")[1].find("2") + dim = ctx.qkv_layout.replace("paged_kv_", 
"").split("_")[1].find("2") dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim) dkv_c_fp8 = dkv_fp8.view( -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] @@ -7557,7 +7605,9 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join([i for i in qkv_layout.replace("paged_kv_","").split("_")[0] if i.isalpha()]) + qkv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) if qkv_format in ["sbhd", "bshd"]: if qkv_format == "sbhd": @@ -8392,8 +8442,9 @@ def forward( # force tensors to be contiguous if not already query_layer, key_layer, value_layer = [ - x.contiguous() if not x.is_contiguous() else x for x in [ - query_layer, key_layer, value_layer]] + x.contiguous() if not x.is_contiguous() else x + for x in [query_layer, key_layer, value_layer] + ] # reshape the query tensor # cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but @@ -8401,16 +8452,17 @@ def forward( # same qkv_format target_qkv_format = inference_params.qkv_format query_layer = inference_params.reshape_and_copy_q( - query_layer, qkv_format, target_qkv_format, self.layer_number) + query_layer, qkv_format, target_qkv_format, self.layer_number + ) - # update KV cache and return the full key/value tensors + # update KV cache and return the full key/value tensors # full key/value tensors are in inference_params.qkv_format format key_layer, value_layer, page_table = inference_params.update_cache( self.layer_number, key_layer, value_layer, qkv_format, - ) + ) # update cu_seqlens tensors if inference_params.is_cuda_graph: @@ -8429,7 +8481,11 @@ def forward( assert ( key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups_per_partition} heads! Found {key_layer.shape[-2]} in key_layer and {value_layer.shape[-2]} in value_layer." 
+ ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads! Found {key_layer.shape[-2]} in" + f" key_layer and {value_layer.shape[-2]} in value_layer." + ) cp_size = 1 if isinstance(self.cp_group, dist_group_type): @@ -8490,7 +8546,7 @@ def forward( ) # convert qkv layout to its corresponding paged attention layout if inference_params is not None and inference_params.is_paged: - qkv_layout = "paged_kv_"+qkv_format+"_2"+inference_params.qkv_format + qkv_layout = "paged_kv_" + qkv_format + "_2" + inference_params.qkv_format global _alibi_cache if alibi_slopes is not None: @@ -8768,11 +8824,11 @@ def forward( if orig_qkv_format == "bshd": output = output[:batch_size, :max_seqlen_q].contiguous() if orig_qkv_format == "sbhd": - output = output[:batch_size, :max_seqlen_q].transpose(0,1).contiguous() + output = output[:batch_size, :max_seqlen_q].transpose(0, 1).contiguous() if orig_qkv_format == "thd": - packed_output = torch.Tensor().to(dtype=output.dtype,device=output.device) + packed_output = torch.Tensor().to(dtype=output.dtype, device=output.device) for i in range(batch_size): - packed_output = torch.cat([packed_output, output[i,:step_lens[i]]], dim=0) + packed_output = torch.cat([packed_output, output[i, : step_lens[i]]], dim=0) output = packed_output.contiguous() return output @@ -9509,7 +9565,9 @@ def forward( elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) else: - raise ValueError(f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE.") + raise ValueError( + f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE." 
+ ) # TODO: consider cases where sequences have different seqlens sequence_start = inference_params.seqlens[0] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index ca672902a2..8ee960b361 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -110,8 +110,7 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, - const c10::optional page_table_v, + const c10::optional page_table_k, const c10::optional page_table_v, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, const c10::optional scale_S, const int scale_S_offset, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 1a41c8cb5a..a68d2086c0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -784,8 +784,7 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, - const c10::optional page_table_v, + const c10::optional page_table_k, const c10::optional page_table_v, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, const c10::optional scale_S, const int scale_S_offset, @@ -890,12 +889,12 @@ std::vector fused_attn_fwd( std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; auto page_table_v_sizes = page_table_v.value().sizes().vec(); std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; - 
te_page_table_k = makeTransformerEngineTensor(page_table_k.value().data_ptr(), - page_table_k_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_page_table_v = makeTransformerEngineTensor(page_table_v.value().data_ptr(), - page_table_v_shape, DType::kInt32, - nullptr, nullptr, nullptr); + te_page_table_k = + makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_page_table_v = + makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // extract rng seed and offset @@ -915,13 +914,13 @@ std::vector fused_attn_fwd( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, - max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_fwd( + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), + &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), + te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], + workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -957,13 +956,13 @@ std::vector fused_attn_fwd( } // execute the kernel - 
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, - max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_fwd( + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), + &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), + te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], + workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 2b62ac2067..184d496218 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -7,19 +7,22 @@ from typing import List, Optional import torch + class NonPagedKVCacheManager: """ The non-paged KV cache manager. 
""" - def __init__(self, - max_batch_size: int, - max_seqlen: int, - num_heads: int, - head_dim_k: int, - dtype: torch.dtype, - head_dim_v: Optional[int] = None, - is_cuda_graph: bool = False, - ): + + def __init__( + self, + max_batch_size: int, + max_seqlen: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + is_cuda_graph: bool = False, + ): """Initialize the KV cache""" self.max_batch_size = max_batch_size self.max_seqlen = max_seqlen @@ -54,7 +57,14 @@ def allocate_memory(self, layer_number): ) self.cache[layer_number] = (k_cache, v_cache) - def step(self, layer_number, k: torch.Tensor, v: torch.Tensor, step_dict: OrderedDict, qkv_format: str): + def step( + self, + layer_number, + k: torch.Tensor, + v: torch.Tensor, + step_dict: OrderedDict, + qkv_format: str, + ): """ Update the non-paged KV cache for a given inference iteration. For more details, please refer to InferenceParams.update_cache(). @@ -86,10 +96,13 @@ def step(self, layer_number, k: torch.Tensor, v: torch.Tensor, step_dict: Ordere # Reorder cache unfinished_seqs = self.sequences.keys() & step_dict.keys() finished_seqs = self.sequences.keys() - unfinished_seqs - unfinished_indices = [i for i,j in enumerate(self.sequences) if j in unfinished_seqs] - finished_indices = [i for i,j in enumerate(self.sequences) if j in finished_seqs] - batch_indices = unfinished_indices + finished_indices \ - + list(range(prev_batch_size, self.max_batch_size)) + unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] + finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] + batch_indices = ( + unfinished_indices + + finished_indices + + list(range(prev_batch_size, self.max_batch_size)) + ) new_k_cache = k_cache[batch_indices, :] new_v_cache = v_cache[batch_indices, :] new_k_cache = new_k_cache.contiguous() @@ -110,19 +123,19 @@ def step(self, layer_number, k: torch.Tensor, v: torch.Tensor, step_dict: Ordere # 
Copy new key/value tokens to cache step_lens = list(step_dict.values()) - cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1,batch_size+1)] - for i,seq in enumerate(self.sequences): - seq_s = self.sequences[seq] - step_dict[seq] + cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] + for i, seq in enumerate(self.sequences): + seq_s = self.sequences[seq] - step_dict[seq] seq_e = self.sequences[seq] - if qkv_format == 'bshd': - new_k_cache[i, seq_s:seq_e, :, :] = k[i, :step_dict[seq], :, :] - new_v_cache[i, seq_s:seq_e, :, :] = v[i, :step_dict[seq], :, :] - if qkv_format == 'sbhd': - new_k_cache[i, seq_s:seq_e, :, :] = k[:step_dict[seq], i, :, :] - new_v_cache[i, seq_s:seq_e, :, :] = v[:step_dict[seq], i, :, :] - if qkv_format == 'thd': - new_k_cache[i, seq_s:seq_e, :, :] = k[cu_seqlens[i]:cu_seqlens[i+1], :, :] - new_v_cache[i, seq_s:seq_e, :, :] = v[cu_seqlens[i]:cu_seqlens[i+1], :, :] + if qkv_format == "bshd": + new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_dict[seq], :, :] + new_v_cache[i, seq_s:seq_e, :, :] = v[i, : step_dict[seq], :, :] + if qkv_format == "sbhd": + new_k_cache[i, seq_s:seq_e, :, :] = k[: step_dict[seq], i, :, :] + new_v_cache[i, seq_s:seq_e, :, :] = v[: step_dict[seq], i, :, :] + if qkv_format == "thd": + new_k_cache[i, seq_s:seq_e, :, :] = k[cu_seqlens[i] : cu_seqlens[i + 1], :, :] + new_v_cache[i, seq_s:seq_e, :, :] = v[cu_seqlens[i] : cu_seqlens[i + 1], :, :] self.cache[layer_number] = (new_k_cache, new_v_cache) # Return full key/value tensors for attention calculation diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index 8591a843d1..3c75a9fc47 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -12,6 +12,7 @@ class Page(object): """A single page""" + def __init__(self, page_id: int): self.page_id = page_id self.allocated = 0 @@ -22,23 +23,26 @@ def 
allocate_page(self): def deallocate_page(self): self.allocated = False + class PagedKVCacheManager(object): """ Paged KV cache manager. It supports a set of utilities including adding and removing sequences, and copying new key/value tokens to the cache. Users can overwrite this class for more custom implementations. """ - def __init__(self, - total_num_pages: int, - page_size: int, - num_heads: int, - head_dim_k: int, - dtype: torch.dtype, - max_batch_size: int, - max_seqlen: int, - head_dim_v: Optional[int] = None, - is_cuda_graph: bool = False, - ): + + def __init__( + self, + total_num_pages: int, + page_size: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + max_batch_size: int, + max_seqlen: int, + head_dim_v: Optional[int] = None, + is_cuda_graph: bool = False, + ): """Initialize the KV cache""" self.total_num_pages = total_num_pages self.page_size = page_size @@ -65,27 +69,44 @@ def __init__(self, def allocate_memory(self, layer_number): """Allocate memory for the KV cache""" k_cache = torch.empty( - self.total_num_pages, self.page_size, self.num_heads, self.head_dim_k, - dtype=self.dtype, device=torch.cuda.current_device()) + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) v_cache = torch.empty( - self.total_num_pages, self.page_size, self.num_heads, self.head_dim_v, - dtype=self.dtype, device=torch.cuda.current_device()) + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) self.cache[layer_number] = (k_cache, v_cache) for i in range(self.total_num_pages): self.free_pages.append(Page(i)) if self.is_cuda_graph: self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device='cuda') + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + ) def print_cache(self): """Print KV cache status""" used_pages = 
[self.get_page_count(seq) for seq in self.sequences] logger = logging.getLogger("PagedAttention") logger.debug("cache status:") - logger.debug(f" total pages: {self.total_num_pages} (used {sum(used_pages)}, free {len(self.free_pages)})") + logger.debug( + f" total pages: {self.total_num_pages} (used {sum(used_pages)}, free" + f" {len(self.free_pages)})" + ) logger.debug(f" total sequences: {self.get_sequence_count()}") for i, seq in enumerate(self.sequences): - logger.debug(f" >> batch index {i}: seq_id {seq}, num_tokens {self.get_sequence_lengths()[i]}, num_pages {self.get_page_count(seq)}, page_list {self.get_page_list(seq)}") + logger.debug( + f" >> batch index {i}: seq_id {seq}, num_tokens {self.get_sequence_lengths()[i]}," + f" num_pages {self.get_page_count(seq)}, page_list {self.get_page_list(seq)}" + ) def get_sequence_count(self): """Get the total number of sequences in the KV cache""" @@ -115,13 +136,16 @@ def get_page_token_offsets(self, seqlen: int): def get_page_table(self, sequences: List[int]): """Get the page table, in shape [batch_size, max_pages_per_seq]""" - page_table = torch.Tensor([self.get_page_list(seq) + \ - [0]*(self.max_pages_per_seq-self.get_page_count(seq)) \ - for seq in sequences]).to(dtype=torch.int32, device='cpu') + page_table = torch.Tensor( + [ + self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) + for seq in sequences + ] + ).to(dtype=torch.int32, device="cpu") if self.is_cuda_graph: - self.page_table[:self.get_sequence_count()].copy_(page_table) + self.page_table[: self.get_sequence_count()].copy_(page_table) else: - self.page_table = page_table.to(device='cuda') + self.page_table = page_table.to(device="cuda") return self.page_table def allocate_page(self, seq: int): @@ -148,7 +172,14 @@ def deallocate_sequence(self, seq: int): self.free_pages.append(page) self.allocated_pages.pop(seq) - def step(self, layer_number: int, k: torch.Tensor, v: torch.Tensor, step_dict: OrderedDict, qkv_format: 
str): + def step( + self, + layer_number: int, + k: torch.Tensor, + v: torch.Tensor, + step_dict: OrderedDict, + qkv_format: str, + ): """ Update the paged KV cache for a given inference iteration. For more details, please refer to InferenceParams.update_cache(). @@ -175,7 +206,7 @@ def step(self, layer_number: int, k: torch.Tensor, v: torch.Tensor, step_dict: O """ batch_size = len(step_dict) step_lens = list(step_dict.values()) - cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1,batch_size+1)] + cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] # Remove finished sequences and advance unfinished sequences unfinished_seqs = self.sequences.keys() & step_dict.keys() @@ -184,8 +215,7 @@ def step(self, layer_number: int, k: torch.Tensor, v: torch.Tensor, step_dict: O self.sequences.pop(seq) self.deallocate_sequence(seq) for seq in unfinished_seqs: - if (self.sequences[seq] % self.page_size == 0 - and self.sequences[seq] < self.max_seqlen): + if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: self.allocate_page(seq) self.sequences[seq] += 1 @@ -200,41 +230,43 @@ def step(self, layer_number: int, k: torch.Tensor, v: torch.Tensor, step_dict: O packed_k = torch.Tensor([]).to(dtype=k.dtype, device=k.device) packed_v = torch.Tensor([]).to(dtype=v.dtype, device=v.device) for i in range(batch_size): - if qkv_format == 'bshd': - packed_k = torch.cat([packed_k, k[i, :step_lens[i], :, :]], dim=0) - packed_v = torch.cat([packed_v, v[i, :step_lens[i], :, :]], dim=0) - if qkv_format == 'sbhd': - packed_k = torch.cat([packed_k, k[:step_lens[i], i, :, :]], dim=0) - packed_v = torch.cat([packed_v, v[:step_lens[i], i, :, :]], dim=0) - if qkv_format == 'thd': + if qkv_format == "bshd": + packed_k = torch.cat([packed_k, k[i, : step_lens[i], :, :]], dim=0) + packed_v = torch.cat([packed_v, v[i, : step_lens[i], :, :]], dim=0) + if qkv_format == "sbhd": + packed_k = torch.cat([packed_k, k[: step_lens[i], i, :, :]], dim=0) + 
packed_v = torch.cat([packed_v, v[: step_lens[i], i, :, :]], dim=0) + if qkv_format == "thd": packed_k = k packed_v = v k_cache, v_cache = self.cache[layer_number] - for i,seq in enumerate(step_dict.keys()): + for i, seq in enumerate(step_dict.keys()): page_list = self.get_page_list(seq) - start_page, start_token = self.get_page_token_offsets( - seqlens[i]-step_lens[i]) - end_page, end_token = self.get_page_token_offsets( - seqlens[i]) + start_page, start_token = self.get_page_token_offsets(seqlens[i] - step_lens[i]) + end_page, end_token = self.get_page_token_offsets(seqlens[i]) if start_page == end_page: page_id = page_list[start_page] - k_cache[page_id,start_token:end_token,:,:] = \ - packed_k[cu_seqlens[i]:cu_seqlens[i+1],:,:] - v_cache[page_id,start_token:end_token,:,:] = \ - packed_v[cu_seqlens[i]:cu_seqlens[i+1],:,:] + k_cache[page_id, start_token:end_token, :, :] = packed_k[ + cu_seqlens[i] : cu_seqlens[i + 1], :, : + ] + v_cache[page_id, start_token:end_token, :, :] = packed_v[ + cu_seqlens[i] : cu_seqlens[i + 1], :, : + ] else: start_offset = 0 end_offset = 0 - for j in range(start_page, end_page+1): + for j in range(start_page, end_page + 1): if not (j == end_page and end_token == 0): start_token_j = start_token if j == start_page else 0 end_token_j = end_token if j == end_page else self.page_size page_id = page_list[start_page] end_offset = end_token_j - start_token_j - k_cache[page_id,start_token_j:end_token_j,:,:] = \ - packed_k[cu_seqlens[i]+start_offset:cu_seqlens[i]+end_offset,:,:] - v_cache[page_id,start_token_j:end_token_j,:,:] = \ - packed_v[cu_seqlens[i]+start_offset:cu_seqlens[i]+end_offset,:,:] + k_cache[page_id, start_token_j:end_token_j, :, :] = packed_k[ + cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : + ] + v_cache[page_id, start_token_j:end_token_j, :, :] = packed_v[ + cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : + ] start_offset = start_offset + end_offset # Get page table From 
b4efd7147288cf28a7cd268eb7831c6e092f14c4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:15:21 -0800 Subject: [PATCH 020/239] remove unnecessary import in test_numerics Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1dc0b56d99..3f09ba9269 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -35,7 +35,6 @@ Fp8Padding, Fp8Unpadding, ) -from transformer_engine.pytorch.attention import _cu_seqlens_cache from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace From e637a079b23eaafbeb5a005b1382e2eecb5b5dfe Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:19:05 -0800 Subject: [PATCH 021/239] add license for test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 9b0cc54d92..ff4a7ea0f4 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ from collections import OrderedDict import os import logging From 767c8f53a0e150c3716caa4fb1ddee144a6f640a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:40:29 -0800 Subject: [PATCH 022/239] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 82 +++++++++---------- .../pytorch/kv_cache_manager_non_paged.py | 14 ++-- .../pytorch/kv_cache_manager_paged.py | 15 ++-- 3 files changed, 53 insertions(+), 58 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c116f567c6..520d1dc89d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -13,11 +13,11 @@ import warnings import logging import functools -from einops import rearrange from dataclasses import dataclass, fields import numpy as np from packaging.version import Version as PkgVersion +from einops import rearrange import torch import torch.nn.functional as F @@ -1134,18 +1134,18 @@ def __init__( def print(self): """Print InferenceParams parameters""" logger = logging.getLogger("InferenceParams") - logger.debug(f"InferenceParams:") - logger.debug(f" dtype: {self.dtype}") - logger.debug(f" is_paged: {self.is_paged}") + logger.debug("InferenceParams:") + logger.debug(" dtype: %s", self.dtype) + logger.debug(" is_paged: %s", self.is_paged) if not self.is_paged: - logger.debug(f" max_batch_size: {self.max_batch_size}") - logger.debug(f" max_seqlen_kv: {self.max_seqlen_kv}") + logger.debug(" max_batch_size: %s", self.max_batch_size) + logger.debug(" max_seqlen_kv: %s", self.max_seqlen_kv) else: - logger.debug(f" total_num_pages: {self.total_num_pages}") - logger.debug(f" page_size: {self.page_size}") - logger.debug(f" num_heads_kv: {self.num_heads_kv}") - logger.debug(f" head_dim: k: {self.head_dim_k}, v: {self.head_dim_v}") - logger.debug(f" layer_numbers: 
{self.layer_numbers}") + logger.debug(" total_num_pages: %s", self.total_num_pages) + logger.debug(" page_size: %s", self.page_size) + logger.debug(" num_heads_kv: %s", self.num_heads_kv) + logger.debug(" head_dim: k: %s, v: %s", self.head_dim_k, self.head_dim_v) + logger.debug(" layer_numbers: %s", self.layer_numbers) def allocate_memory(self, layer_number): """ @@ -1190,7 +1190,7 @@ def reshape_and_copy_q( self, q: torch.Tensor, source_qkv_format: str, - target_qkv_format: str, + target_qkv_format: str, # pylint: disable=unused-argument layer_number: Optional[int] = None, ): """ @@ -1230,31 +1230,31 @@ def reshape_and_copy_q( # bshd: [actual_batch_size, batch_wide_max_seqlen_q, num_heads_q, head_dim_q] return q - else: - assert ( - layer_number is not None and layer_number in self.layer_numbers - ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" - q_buffer = self.q_buffer[layer_number] - for i in range(actual_batch_size): - if source_qkv_format == "bshd": - q_buffer[i, : seqlens_q[i], :, :] = q[i, : seqlens_q[i], :, :] - if source_qkv_format == "sbhd": - q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] - if source_qkv_format == "thd": - q_buffer[i, : seqlens_q[i], :, :] = q[ - cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : - ] - q_buffer[i, seqlens_q[i] :, :, :].fill_(0) - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * ( - self.max_batch_size - actual_batch_size - ) - self.cu_seqlens_q_buffer.copy_( - torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") - ) + assert ( + layer_number is not None and layer_number in self.layer_numbers + ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" 
+ q_buffer = self.q_buffer[layer_number] + for i in range(actual_batch_size): + if source_qkv_format == "bshd": + q_buffer[i, : seqlens_q[i], :, :] = q[i, : seqlens_q[i], :, :] + if source_qkv_format == "sbhd": + q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] + if source_qkv_format == "thd": + q_buffer[i, : seqlens_q[i], :, :] = q[ + cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : + ] + q_buffer[i, seqlens_q[i] :, :, :].fill_(0) - # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] - return q_buffer + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * ( + self.max_batch_size - actual_batch_size + ) + self.cu_seqlens_q_buffer.copy_( + torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") + ) + + # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] + return q_buffer def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): """ @@ -1379,17 +1379,11 @@ def update_cache( page_table: torch.Tensor The page table if is_paged = True; else `None` """ - outputs = self.cache_manager.step(layer_number, k, v, self.step_dict, qkv_format) + k_cache, v_cache, page_table = self.cache_manager.step(layer_number, k, v, self.step_dict, qkv_format) + self.page_table = page_table self.seq_ids = list(self.cache_manager.sequences.keys()) self.seqlens = list(self.cache_manager.sequences.values()) - if not self.is_paged: - k_cache, v_cache = outputs - page_table = None - else: - k_cache, v_cache, page_table = outputs - self.page_table = page_table - if self.is_cuda_graph: actual_batch_size = len(self.seqlens) cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size + 1)] @@ -8818,7 +8812,6 @@ def forward( if inference_params is not None: batch_size = len(inference_params.step_dict) - seqlen = inference_params.seqlens[0] step_lens = list(inference_params.step_dict.values()) max_seqlen_q = max(list(inference_params.step_dict.values())) if orig_qkv_format == "bshd": @@ -9569,6 +9562,7 @@ def 
forward( f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE." ) + # pylint: disable=fixme # TODO: consider cases where sequences have different seqlens sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 184d496218..0a72968f67 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -4,7 +4,7 @@ """Non-Paged KV Cache Manager.""" from collections import OrderedDict -from typing import List, Optional +from typing import Optional import torch @@ -141,9 +141,9 @@ def step( # Return full key/value tensors for attention calculation if self.is_cuda_graph: # [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] - return new_k_cache, new_v_cache - else: - # [actual_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] - new_k_cache = new_k_cache[:batch_size].contiguous() - new_v_cache = new_v_cache[:batch_size].contiguous() - return new_k_cache, new_v_cache + return new_k_cache, new_v_cache, None + + # [actual_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] + new_k_cache = new_k_cache[:batch_size].contiguous() + new_v_cache = new_v_cache[:batch_size].contiguous() + return new_k_cache, new_v_cache, None diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index 3c75a9fc47..2425708471 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -10,21 +10,24 @@ import torch -class Page(object): +class Page: """A single page""" def __init__(self, page_id: int): + """Initialize a page""" self.page_id = page_id self.allocated = 0 def allocate_page(self): + """Allocate a page""" self.allocated = True def deallocate_page(self): + """Deallocate a page""" 
self.allocated = False -class PagedKVCacheManager(object): +class PagedKVCacheManager: """ Paged KV cache manager. It supports a set of utilities including adding and removing sequences, and copying new key/value tokens to the cache. Users can overwrite this class @@ -98,14 +101,12 @@ def print_cache(self): logger = logging.getLogger("PagedAttention") logger.debug("cache status:") logger.debug( - f" total pages: {self.total_num_pages} (used {sum(used_pages)}, free" - f" {len(self.free_pages)})" + " total pages: %s (used %s, free %s)", self.total_num_pages, sum(used_pages), len(self.free_pages) ) - logger.debug(f" total sequences: {self.get_sequence_count()}") + logger.debug(" total sequences: %s", self.get_sequence_count()) for i, seq in enumerate(self.sequences): logger.debug( - f" >> batch index {i}: seq_id {seq}, num_tokens {self.get_sequence_lengths()[i]}," - f" num_pages {self.get_page_count(seq)}, page_list {self.get_page_list(seq)}" + " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", i, seq, self.get_sequence_lengths()[i], self.get_page_count(seq), self.get_page_list(seq) ) def get_sequence_count(self): From a3bb14fe698cc1157535c924aa97920f04825593 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:43:55 -0800 Subject: [PATCH 023/239] add to L0 test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17307574a9..d1139e42f9 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -14,6 +14,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py NVTE_TORCH_COMPILE=0 
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py +pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py From d65933c8b39acb9714eadaf4fe740036a62b21a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 05:45:07 +0000 Subject: [PATCH 024/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 14 ++++++-------- .../pytorch/kv_cache_manager_paged.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 520d1dc89d..c278b806f5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1190,7 +1190,7 @@ def reshape_and_copy_q( self, q: torch.Tensor, source_qkv_format: str, - target_qkv_format: str, # pylint: disable=unused-argument + target_qkv_format: str, # pylint: disable=unused-argument layer_number: Optional[int] = None, ): """ @@ -1241,14 +1241,10 @@ def reshape_and_copy_q( if source_qkv_format == "sbhd": q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] if source_qkv_format == "thd": - q_buffer[i, : seqlens_q[i], :, :] = q[ - cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : - ] + q_buffer[i, : seqlens_q[i], :, :] = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, :] q_buffer[i, seqlens_q[i] :, :, :].fill_(0) - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * ( - self.max_batch_size - actual_batch_size - ) + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) self.cu_seqlens_q_buffer.copy_( 
torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") ) @@ -1379,7 +1375,9 @@ def update_cache( page_table: torch.Tensor The page table if is_paged = True; else `None` """ - k_cache, v_cache, page_table = self.cache_manager.step(layer_number, k, v, self.step_dict, qkv_format) + k_cache, v_cache, page_table = self.cache_manager.step( + layer_number, k, v, self.step_dict, qkv_format + ) self.page_table = page_table self.seq_ids = list(self.cache_manager.sequences.keys()) self.seqlens = list(self.cache_manager.sequences.values()) diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index 2425708471..cf4cba5b71 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -101,12 +101,20 @@ def print_cache(self): logger = logging.getLogger("PagedAttention") logger.debug("cache status:") logger.debug( - " total pages: %s (used %s, free %s)", self.total_num_pages, sum(used_pages), len(self.free_pages) + " total pages: %s (used %s, free %s)", + self.total_num_pages, + sum(used_pages), + len(self.free_pages), ) logger.debug(" total sequences: %s", self.get_sequence_count()) for i, seq in enumerate(self.sequences): logger.debug( - " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", i, seq, self.get_sequence_lengths()[i], self.get_page_count(seq), self.get_page_list(seq) + " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", + i, + seq, + self.get_sequence_lengths()[i], + self.get_page_count(seq), + self.get_page_list(seq), ) def get_sequence_count(self): From d3cbccdf9e98a5b3eb756a61e5c1a744b6daf06f Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Wed, 4 Dec 2024 09:52:54 -0600 Subject: [PATCH 025/239] [JAX] Scale sequence length in CP tests to avoid tiny sizes. (#1347) Scale sequence length in CP tests to avoid tiny sizes. 
Signed-off-by: Michael Goldfarb --- tests/jax/test_distributed_fused_attn.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 7ef0d68474..e194a228d2 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -341,8 +341,9 @@ def ref_func(query, kv, mask): @pytest.mark.parametrize( "data_shape", [ - pytest.param([2, 512, 12, 128], id="2-512-12-128"), - pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"), + pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ], ) @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) @@ -423,6 +424,12 @@ def impl_test_contex_parallel_attn( qkv_format = get_qkv_format(qkv_layout) batch, seqlen, num_head, hidden = data_shape + + # Scale the sequence length by 2*CP so its never too small as we scale up test. + # 2*CP is used since we split into two CP groups for load balancing. 
+ seqlen = seqlen * cp_size * 2 + data_shape = batch, seqlen, num_head, hidden + num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head) From 71ada55fb4639ded1be9e125ec0eee0870a6d7f3 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:18:55 -0800 Subject: [PATCH 026/239] Debug nightly docs (#1338) Debug jobs to deploy nightly docs Signed-off-by: Tim Moon --- .github/workflows/deploy_nightly_docs.yml | 3 ++- .github/workflows/docs.yml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index cd68019c8f..fc5e27d0a4 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -16,13 +16,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Download artifact - uses: actions/download-artifact@v4.1.7 + uses: actions/download-artifact@v4 with: name: "te_docs" path: "html" - name: Prepare for pages uses: actions/upload-pages-artifact@v1.0.7 with: + name: github-pages path: "html" deploy: needs: prepare diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 4762cccee6..b6fadba1bd 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: cd docs make html - name: 'Upload docs' - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: te_docs path: docs/_build/html From 8c004241d18cb769f705436fcae8b0789941075d Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:19:53 -0800 Subject: [PATCH 027/239] [PyTorch] Store module extra state in tensor (#1335) Store module extra state in tensor Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/base.py | 91 +++++++++++++++++------ transformer_engine/pytorch/ops/op.py | 2 +- 2 files changed, 68 insertions(+), 25 deletions(-) diff --git 
a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3a15242c3a..d115efedaa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -588,20 +588,50 @@ def reset(key): def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" - state = None + # This implementation is working around a few issues: + # + # (1) PyTorch's "extra state" infrastructure might be able to + # support any picklable type, but they make no guarantees. + # We have experienced problems (e.g. in ONNX export) with + # non-tensor extra state. + # (2) PyTorch's checkpointing infrastructure does not remap + # devices for "extra state" like it does for "state dict". + # Thus, we want to avoid putting extra state on the GPU + # since it may be loaded on the wrong device. + # (3) The extra state consists of many small tensors. If we + # want to copy them all to CPU, then we need to avoid the + # overhead of many GPU-CPU memory transfers. + # + # See: https://github.com/NVIDIA/TransformerEngine/pull/351 + # See: https://github.com/NVIDIA/TransformerEngine/pull/363 + + def to_cpu(src: torch.Tensor) -> torch.Tensor: + """Helper function to make CPU copy of tensor + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. 
+ + """ + dst = torch.empty_like(src, device="cpu") + dst.copy_(src, non_blocking=True) + return dst + + # Store FP8 state if needed + state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - if fp8_checkpoint: + + # Copy tensors to CPU and store state = {} - state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale - state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv - state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history - state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale - state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv - state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - - # Store other pickelable values. + state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) + state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) + state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv) + state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) + state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) + state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv) + + # Store other pickelable values extra = {} for k, v in self.fp8_meta.items(): if k != "buffer_index_and_autocast_key" and isinstance( @@ -610,12 +640,10 @@ def get_extra_state(self) -> torch.Tensor: extra[k] = v state["extra_fp8_variables"] = extra - if is_in_onnx_export_mode(): - state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8) - else: - state_serialized = io.BytesIO() - torch.save(state, state_serialized) - + # Serialize state into byte tensor + torch.cuda.synchronize() + state_serialized = bytearray(pickle.dumps(state)) + state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) return state_serialized def set_extra_state(self, state: torch.Tensor) -> None: @@ -623,9 +651,12 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return + # Load 
state if isinstance(state, torch.Tensor): + # Default format: byte tensor with pickled data state = pickle.loads(state.detach().cpu().numpy().tobytes()) elif isinstance(state, io.BytesIO): + # Deprecated format with io.BytesIO state.seek(0) state = torch.load(state, map_location="cuda") else: @@ -634,20 +665,32 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return - # Load extra items. + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] - # Initialize before loading. + # Initialize before loading self.init_fp8_meta_tensors() - self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) - self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) - self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) - self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"]) - self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"]) - self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"]) + + def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: + """Helper function to copy tensor from CPU + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. 
+ + """ + dst.copy_(src, non_blocking=True) + + # Load tensors + copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) + copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) + copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv) + copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) + copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) + copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv) + torch.cuda.synchronize() def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 04a66b7942..c55e0f7c19 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -514,7 +514,7 @@ def get_extra_state(self) -> torch.Tensor: # # (1) PyTorch's "extra state" infrastructure might be able to # support any picklable type, but they make no guarantees. - # It seems that ONNX export experiences issues with + # We have experienced problems (e.g. in ONNX export) with # non-tensor extra state. # (2) PyTorch's checkpointing infrastructure does not remap # devices for "extra state" like it does for "state dict". 
From d978e800be44e43e17181a5d5087eeb77e626c50 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:44:40 -0800 Subject: [PATCH 028/239] Fix attention mask type for Flash Attention + CP + THD (#1354) * always have padding mask type for both flash and fused attentions Signed-off-by: Xiaowei Ren * remove an redundant assert Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren --- tests/pytorch/fused_attn/run_fused_attn_with_cp.py | 2 +- transformer_engine/pytorch/attention.py | 14 +++----------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 2d863b3bba..3ddfab055c 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -42,7 +42,7 @@ def run_dpa_with_cp( "causal", "no_mask", ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if kernel_backend == "FusedAttention" and qkv_format == "thd": + if qkv_format == "thd": if "causal" in config.attn_mask_type: config.attn_mask_type = "padding_causal" else: diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8159f20e90..8c529c58d0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4309,14 +4309,6 @@ def attn_forward_func_with_cp( assert ( qkv_format != "sbhd" or use_fused_attention ), "FlashAttention does not support sbhd format!" - assert ( - qkv_format != "thd" - or not use_fused_attention - or attn_mask_type in ["padding", "padding_causal"] - ), ( - f"Context parallelism is not supported for {attn_mask_type} mask type and " - f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" 
- ) assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" @@ -7878,6 +7870,9 @@ def forward( ), f"Values have head_dim = {value_layer.shape[-1]}, " "but expected head_dim = {self.hidden_size_per_attention_head_v}!" + if qkv_format is None: + qkv_format = self.qkv_format + if attn_mask_type is None: attn_mask_type = self.attn_mask_type else: @@ -7904,9 +7899,6 @@ def forward( graph_safe_rng_available() ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." - if qkv_format is None: - qkv_format = self.qkv_format - if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" From d8b13cb0e58060e2428949faa36e625fe2978d97 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:42:06 -0800 Subject: [PATCH 029/239] Disable FP8 in Mcore integration test on older GPUs (#1357) Debug Mcore integration test Avoid FP8 on Ampere and older. Generate synthetic data instead of depending on external data. 
Signed-off-by: Tim Moon --- qa/L1_pytorch_mcore_integration/.gitignore | 2 ++ qa/L1_pytorch_mcore_integration/merges.txt | 1 + qa/L1_pytorch_mcore_integration/test.sh | 22 ++++++++++++++++++---- 3 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 qa/L1_pytorch_mcore_integration/.gitignore create mode 100644 qa/L1_pytorch_mcore_integration/merges.txt diff --git a/qa/L1_pytorch_mcore_integration/.gitignore b/qa/L1_pytorch_mcore_integration/.gitignore new file mode 100644 index 0000000000..46426003ca --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/.gitignore @@ -0,0 +1,2 @@ +Megatron-LM +vocab.json \ No newline at end of file diff --git a/qa/L1_pytorch_mcore_integration/merges.txt b/qa/L1_pytorch_mcore_integration/merges.txt new file mode 100644 index 0000000000..5e7f1fd949 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/merges.txt @@ -0,0 +1 @@ +#version: 0.2 diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh index 01c9e14eb1..b0aba17ef5 100644 --- a/qa/L1_pytorch_mcore_integration/test.sh +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -8,6 +8,12 @@ set -e : ${TE_PATH:=/opt/transformerengine} : ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} +# Check whether FP8 is supported +DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') +if [[ ${DEVICE_ARCH} -ge 89 ]]; then + WITH_FP8=1 +fi + # Download Megatron-LM if needed if [ ! -d "${MCORE_PATH}" ]; then pushd $(dirname ${MCORE_PATH}) @@ -15,6 +21,14 @@ if [ ! 
-d "${MCORE_PATH}" ]; then popd fi +# Create mock vocab +VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json +printf "" > ${VOCAB_FILE} +printf "{" >> ${VOCAB_FILE} +printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE} +seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE} +printf "}" >> ${VOCAB_FILE} + # Megatron-LM invocation COMMAND=" NVTE_TORCH_COMPILE=0 @@ -40,17 +54,17 @@ ${MCORE_PATH}/pretrain_gpt.py --hidden-size 128 --num-attention-heads 8 --seq-length 128 ---max-position-embeddings 2048 +--max-position-embeddings 128 --micro-batch-size 1 --global-batch-size 8 --train-iters 10 --eval-iters 10 --lr 1e-4 --mock-data ---vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json ---merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt +--vocab-file ${VOCAB_FILE} +--merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt --transformer-impl transformer_engine ---fp8-format hybrid +${WITH_FP8:+--fp8-format hybrid} " COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') From 3102fdd160703b7bb76dfad291ab97493c320a6b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 6 Dec 2024 13:58:59 -0500 Subject: [PATCH 030/239] [C] Normalization Refactor + Adding CUDNN backend (#1315) * cuDNN normalization integration * TE Norm refactor * TE Norm APIs changes. 
--------- Signed-off-by: Phuong Nguyen Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/cpp/operator/CMakeLists.txt | 3 +- tests/cpp/operator/test_layernorm.cu | 302 ------------ tests/cpp/operator/test_normalization.cu | 380 +++++++++++++++ tests/cpp/operator/test_rmsnorm.cu | 249 ---------- transformer_engine/common/CMakeLists.txt | 13 +- .../include/transformer_engine/layer_norm.h | 159 ------ .../{rmsnorm.h => normalization.h} | 136 +++--- transformer_engine/common/layer_norm/ln.h | 239 --------- .../common/layer_norm/ln_api.cpp | 457 ------------------ .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 345 ------------- .../common/layer_norm/ln_fwd_cuda_kernel.cu | 413 ---------------- .../common/normalization/common.cpp | 445 +++++++++++++++++ .../common/normalization/common.h | 382 +++++++++++++++ .../kernel_traits.h} | 19 +- .../common/normalization/layernorm/ln_api.cpp | 184 +++++++ .../layernorm}/ln_bwd_kernels.cuh | 25 +- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 331 +++++++++++++ .../layernorm/ln_fwd_cuda_kernel.cu | 395 +++++++++++++++ .../layernorm}/ln_fwd_kernels.cuh | 17 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 166 +++++++ .../rmsnorm/rmsnorm_bwd_kernels.cuh | 19 +- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 206 ++++++++ .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 210 ++++++++ .../rmsnorm/rmsnorm_fwd_kernels.cuh | 12 +- transformer_engine/common/rmsnorm/rmsnorm.h | 89 ---- .../common/rmsnorm/rmsnorm_api.cpp | 387 --------------- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 220 --------- .../common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 227 --------- .../common/rmsnorm/rmsnorm_kernel_traits.h | 42 -- .../jax/cpp_extensions/normalization.py | 181 ++----- transformer_engine/jax/csrc/extensions.h | 15 +- .../jax/csrc/extensions/normalization.cpp | 270 +++-------- .../jax/csrc/extensions/packing.cpp | 15 +- transformer_engine/paddle/csrc/common.h | 3 +- transformer_engine/paddle/csrc/custom_ops.cu | 98 ++-- 
transformer_engine/pytorch/csrc/common.h | 3 +- .../pytorch/csrc/extensions/normalization.cpp | 99 ++-- 37 files changed, 3029 insertions(+), 3727 deletions(-) delete mode 100644 tests/cpp/operator/test_layernorm.cu create mode 100644 tests/cpp/operator/test_normalization.cu delete mode 100644 tests/cpp/operator/test_rmsnorm.cu delete mode 100644 transformer_engine/common/include/transformer_engine/layer_norm.h rename transformer_engine/common/include/transformer_engine/{rmsnorm.h => normalization.h} (55%) delete mode 100644 transformer_engine/common/layer_norm/ln.h delete mode 100644 transformer_engine/common/layer_norm/ln_api.cpp delete mode 100644 transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu delete mode 100644 transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu create mode 100644 transformer_engine/common/normalization/common.cpp create mode 100644 transformer_engine/common/normalization/common.h rename transformer_engine/common/{layer_norm/ln_kernel_traits.h => normalization/kernel_traits.h} (89%) create mode 100644 transformer_engine/common/normalization/layernorm/ln_api.cpp rename transformer_engine/common/{layer_norm => normalization/layernorm}/ln_bwd_kernels.cuh (97%) create mode 100644 transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu create mode 100644 transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu rename transformer_engine/common/{layer_norm => normalization/layernorm}/ln_fwd_kernels.cuh (97%) create mode 100644 transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp rename transformer_engine/common/{ => normalization}/rmsnorm/rmsnorm_bwd_kernels.cuh (97%) create mode 100644 transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu create mode 100644 transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu rename transformer_engine/common/{ => normalization}/rmsnorm/rmsnorm_fwd_kernels.cuh (98%) delete mode 100644 
transformer_engine/common/rmsnorm/rmsnorm.h delete mode 100644 transformer_engine/common/rmsnorm/rmsnorm_api.cpp delete mode 100644 transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu delete mode 100644 transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu delete mode 100644 transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 45806e7022..ab6b6a5316 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,8 +10,7 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu - test_layernorm.cu - test_rmsnorm.cu + test_normalization.cu test_multi_cast_transpose.cu test_multi_padding.cu test_causal_softmax.cu diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu deleted file mode 100644 index cdd8e7846c..0000000000 --- a/tests/cpp/operator/test_layernorm.cu +++ /dev/null @@ -1,302 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -template -void compute_ref_stats(const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon) { - using compute_t = float; - for (size_t i = 0 ; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += current; - } - mu[i] = sum / H; - compute_t m = mu[i]; - sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += (current - m) * (current - m); - } - sum = sum / H; - compute_t rs = rsqrtf(sum + epsilon); - rsigma[i] = rs; - } -} - -template -void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta, - OutputType *output, const float *mu, const float *rsigma, - const size_t N, const size_t H, - float *amax, float scale, const bool zero_centered_gamma) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0 ; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - compute_t tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - -template -void compute_ref_backward(const OutputType *output_grad, const InputType *data, - const float *mu, const float *rsigma, - const InputType *gamma, - InputType *data_grad, - InputType *gamma_grad, InputType *beta_grad, - const size_t N, const size_t H, - const bool zero_centered_gamma) { - using compute_t = float; - std::vector dgamma(H, 0.f); - 
std::vector dbeta(H, 0.f); - - for (size_t i = 0 ; i < N; ++i) { - // Reductions - compute_t mdy = 0, mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - mu[i]) * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - dbeta[j] += dz; - mdy += dy; - mdyy += dy * y; - } - mdy /= H; - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - mu[i]) * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) { - gamma_grad[j] = static_cast(dgamma[j]); - beta_grad[j] = static_cast(dbeta[j]); - } -} - -template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) { - if (sizeof(InputType) < sizeof(OutputType)) { - GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; - return; - } - using WeightType = InputType; - DType itype = TypeInfo::dtype; - DType wtype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || - (itype == DType::kFloat16 && otype == DType::kBFloat16)) { - GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; - return; - } - - Tensor input({ N, H }, itype); - Tensor z({ N, H }, otype); - Tensor gamma({ H }, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); - Tensor dz({ N, H }, wtype); - Tensor dx({ N, H }, itype); - Tensor dgamma({ H }, wtype); - Tensor dbeta({ H }, wtype); - 
Tensor workspace, barrier, dgamma_part, dbeta_part; - - fillUniform(&input); - fillUniform(&gamma); - fillUniform(&beta); - setRandomScale(&z); - fillUniform(&dz); - - std::unique_ptr ref_output = std::make_unique(N * H); - std::unique_ptr ref_mu = std::make_unique(N); - std::unique_ptr ref_rsigma = std::make_unique(N); - std::unique_ptr ref_dx = std::make_unique(N * H); - std::unique_ptr ref_dgamma = std::make_unique(H); - std::unique_ptr ref_dbeta = std::make_unique(H); - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - - // Forward kernel - float epsilon = 1e-5; - auto fwd_function = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - fwd_function(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - fwd_function(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - - // Backward kernel - auto bwd_function = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_function(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - dgamma_part.data(), dbeta_part.data(), - 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype()); - bwd_function(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - dgamma_part.data(), dbeta_part.data(), - 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - - // Reference implementations - // use the GPU stats to tighten the tolerances - mu.to_cpu(); - rsigma.to_cpu(); - float ref_amax; - compute_ref_stats(input.cpu_dptr(), ref_mu.get(), - ref_rsigma.get(), N, H, epsilon); - float ref_scale = isFp8Type(otype) ? 
z.scale() : 1.f; - compute_ref_output(input.cpu_dptr(), - gamma.cpu_dptr(), - beta.cpu_dptr(), - ref_output.get(), - mu.cpu_dptr(), - rsigma.cpu_dptr(), - N, H, - &ref_amax, - ref_scale, - zero_centered_gamma); - compute_ref_backward(dz.cpu_dptr(), input.cpu_dptr(), - mu.cpu_dptr(), rsigma.cpu_dptr(), - gamma.cpu_dptr(), - ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), - N, H, zero_centered_gamma); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - if (isFp8Type(otype)) { - compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); - } - - auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); - rtol_stats = 5e-5; - compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); - - auto [atol, rtol] = getTolerances(otype); - if (otype == DType::kFloat32) { - atol = 5e-7; - } - compareResults("output", z, ref_output.get(), atol, rtol); - - double atol_bwd = 1e-4; - double rtol_bwd = 1e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); - compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); -} - -std::vector> test_cases = {{2048, 12288}, - {768, 1024}, - {256, 65536}, - {128, 6144}, - {64, 2304}, - {229, 541}, // Primes 50, 100 - {71, 3571}, // Primes 20, 500 - {29, 17389}}; // Primes 10, 2000 - -} // namespace - -class LNTestSuite : public ::testing::TestWithParam, - bool>> {}; - -TEST_P(LNTestSuite, TestLN) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = 
std::get<2>(GetParam()); - const bool zero_centered_gamma = std::get<3>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma); - ); - ); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - LNTestSuite, - ::testing::Combine( - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo& info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param)); - return name; - }); diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu new file mode 100644 index 0000000000..bd6ee96af8 --- /dev/null +++ b/tests/cpp/operator/test_normalization.cu @@ -0,0 +1,380 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RmsNorm"} +}; + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + compute_t current, m; + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +// For now, cudnn does static_cast(gamma + static_cast(1.0)) +// This will be changed in the future release +template +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){ + + using compute_t = float; + if constexpr (std::is_same_v || std::is_same_v){ + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } else { + if (use_cudnn){ + compute_t g = static_cast(0.f); + InputType gi = gamma; + if (zero_centered_gamma) { + gi = gi + static_cast(1.f); + } + g = static_cast(gi); + return g; + } else { + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + OutputType* output, + const float *mu, const float 
*rsigma, + const size_t N, const size_t H, + float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = static_cast(tmp * scale); + current_max = fmaxf(current_max, fabsf(tmp)); + } + } + *amax = current_max; +} + + +template +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, + const float *mu, const float *rsigma, + const InputType *gamma, + InputType *data_grad, + InputType *gamma_grad, InputType *beta_grad, + const size_t N, const size_t H, + const bool zero_centered_gamma, const bool use_cudnn) { + using compute_t = float; + std::vector dgamma(H, 0.f); + std::vector dbeta(H, 0.f); + + for (size_t i = 0 ; i < N; ++i) { + // Reductions + auto local_mu = (norm_type == LayerNorm) ? 
mu[i] : 0.; + compute_t mdy = 0, mdyy = 0; + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + dgamma[j] += y * dz; + if (norm_type == LayerNorm) { + dbeta[j] += dz; + mdy += dy; + } + mdyy += dy * y; + } + mdy /= H; + mdyy /= H; + + // Input grads + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + data_grad[i * H + j] = static_cast(dx); + } + } + + // Weight grads + for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); + if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); +} + +template +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, + NormType norm_type, bool use_cudnn) { + if (sizeof(InputType) < sizeof(OutputType)) { + GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; + return; + } + using WeightType = InputType; + DType itype = TypeInfo::dtype; + DType wtype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || + (itype == DType::kFloat16 && otype == DType::kBFloat16)) { + GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; + return; + } + + Tensor input({ N, H }, itype); + Tensor z({ N, H }, otype); + Tensor gamma({ H }, wtype); + Tensor beta({ H }, wtype); + Tensor mu({ N }, DType::kFloat32); + Tensor rsigma({ N }, DType::kFloat32); + Tensor dz({ N, H }, wtype); + Tensor dx({ N, H }, itype); + Tensor dgamma({ 
H }, wtype); + Tensor dbeta({ H }, wtype); + Tensor workspace_fwd, workspace_bwd; + + fillUniform(&input); + fillUniform(&gamma); + fillUniform(&beta); + setRandomScale(&z); + fillUniform(&dz); + + std::unique_ptr ref_output = std::make_unique(N * H); + std::unique_ptr ref_mu = std::make_unique(N); + std::unique_ptr ref_rsigma = std::make_unique(N); + std::unique_ptr ref_dx = std::make_unique(N * H); + std::unique_ptr ref_dgamma = std::make_unique(H); + std::unique_ptr ref_dbeta = std::make_unique(H); + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + + if (use_cudnn){ + nvte_enable_cudnn_norm_fwd(true); + nvte_enable_cudnn_norm_bwd(true); + } + + // Forward kernel + float epsilon = 1e-5; + if (norm_type == LayerNorm){ + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), 
workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } + + if (use_cudnn){ + nvte_enable_cudnn_norm_fwd(false); + nvte_enable_cudnn_norm_bwd(false); + } + + // Reference implementations + // use the GPU stats to tighten the tolerances + mu.to_cpu(); + rsigma.to_cpu(); + float ref_amax; + compute_ref_stats(norm_type, input.cpu_dptr(), ref_mu.get(), + ref_rsigma.get(), N, H, epsilon); + float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; + compute_ref_output(norm_type, input.cpu_dptr(), + gamma.cpu_dptr(), + beta.cpu_dptr(), + ref_output.get(), + mu.cpu_dptr(), + rsigma.cpu_dptr(), + N, H, + &ref_amax, + ref_scale, + zero_centered_gamma, + use_cudnn); + compute_ref_backward(norm_type, dz.cpu_dptr(), input.cpu_dptr(), + mu.cpu_dptr(), rsigma.cpu_dptr(), + gamma.cpu_dptr(), + ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), + N, H, zero_centered_gamma, + use_cudnn); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + if (isFp8Type(otype)) { + compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); + rtol_stats = 5e-5; + compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); + + auto [atol, rtol] = 
getTolerances(otype); + if (otype == DType::kFloat32) { + atol = 5e-7; + } + compareResults("output", z, ref_output.get(), atol, rtol); + + double atol_bwd = 5e-4; + double rtol_bwd = 5e-4; + compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); + compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); + compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); +} + +std::vector> test_cases = { + {71, 229}, + {29, 541}, + {768, 6144}, + {2048, 12288}, +}; + +} // namespace + +class NormTestSuite : public ::testing::TestWithParam, + bool>> {}; + +TEST_P(NormTestSuite, TestNorm) { + using namespace transformer_engine; + using namespace test; + + const bool use_cudnn = std::get<0>(GetParam()); + const NormType norm_type = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const auto size = std::get<4>(GetParam()); + const bool zero_centered_gamma = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + NormTestSuite, + ::testing::Combine( + ::testing::Values(false), //TODO: enabling tests for cudnn backend + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + auto backend = std::get<0>(info.param) == false ? 
"Te" : "Cudnn"; +std::string name = + backend + + normToString.at(std::get<1>(info.param)) + "_" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + std::to_string(std::get<4>(info.param).first) + "X" + + std::to_string(std::get<4>(info.param).second) + "X" + + std::to_string(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu deleted file mode 100644 index 0ec3a877e5..0000000000 --- a/tests/cpp/operator/test_rmsnorm.cu +++ /dev/null @@ -1,249 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -template -void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, const size_t H, - const double epsilon) { - using compute_t = float; - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += (current) * (current); - } - sum = sum / H; - compute_t rs = rsqrtf(sum + epsilon); - rsigma[i] = rs; - } -} - -template -void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output, - const float *rsigma, const size_t N, const size_t H, float *amax, - float scale, const bool zero_centered_gamma) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - 
compute_t tmp = current * rsigma[i] * g; - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - -template -void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma, - const InputType *gamma, InputType *data_grad, InputType *gamma_grad, - const size_t N, const size_t H, const bool zero_centered_gamma) { - using compute_t = float; - std::vector dgamma(H, 0.f); - - for (size_t i = 0; i < N; ++i) { - // Reductions - compute_t mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = x * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - mdyy += dy * y; - } - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = x * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) { - gamma_grad[j] = static_cast(dgamma[j]); - } -} - -template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) { - if (sizeof(InputType) < sizeof(OutputType)) { - GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType"; - return; - } - using WeightType = InputType; - DType itype = TypeInfo::dtype; - DType wtype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || - (itype == DType::kFloat16 && otype == DType::kBFloat16)) { - GTEST_SKIP() << "RMSNorm kernel does not support mixing Float16 and 
BFloat16"; - return; - } - - Tensor input({N, H}, itype); - Tensor z({N, H}, otype); - Tensor gamma({H}, wtype); - Tensor rsigma({N}, DType::kFloat32); - Tensor dz({N, H}, wtype); - Tensor dx({N, H}, itype); - Tensor dgamma({H}, wtype); - Tensor workspace, barrier, dgamma_part; - - fillUniform(&input); - fillUniform(&gamma); - fillUniform(&dz); - setRandomScale(&z); - - std::unique_ptr ref_output = std::make_unique(N * H); - std::unique_ptr ref_rsigma = std::make_unique(N); - std::unique_ptr ref_dx = std::make_unique(N * H); - std::unique_ptr ref_dgamma = std::make_unique(H); - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - - // Forward kernel - float epsilon = 1e-5; - auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; - fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, - prop.multiProcessorCount, workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, - prop.multiProcessorCount, workspace.data(), barrier.data()); - - // Backward kernel - auto bwd_function = zero_centered_gamma ? 
nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), - dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), - barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); - bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), - dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), - barrier.data()); - - // Reference implementations - // use the GPU stats to tighten the tolerances - rsigma.to_cpu(); - float ref_amax; - compute_ref_stats(input.cpu_dptr(), ref_rsigma.get(), N, H, epsilon); - float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; - compute_ref_output(input.cpu_dptr(), gamma.cpu_dptr(), ref_output.get(), - rsigma.cpu_dptr(), N, H, &ref_amax, ref_scale, - zero_centered_gamma); - compute_ref_backward(dz.cpu_dptr(), input.cpu_dptr(), - rsigma.cpu_dptr(), gamma.cpu_dptr(), ref_dx.get(), - ref_dgamma.get(), N, H, zero_centered_gamma); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - if (isFp8Type(otype)) { - compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); - } - - auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); - rtol_stats = 5e-5; - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); - - auto [atol, rtol] = getTolerances(otype); - atol = 1e-8; - compareResults("output", z, ref_output.get(), atol, rtol); - - double atol_bwd = 5e-6; - double rtol_bwd = 1e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, 
ref_dgamma.get(), atol_bwd, rtol_bwd); -} - -std::vector> test_cases = { - {2048, 4096}, {768, 2048}, {256, 1024}, {128, 768}, {64, 512}, {173, 409}, // Primes 40, 80 - {71, 3571}, // Primes 20, 500 - {29, 17389}}; // Primes 10, 2000 - -} // namespace - -class RMSNormTestSuite : public ::testing::TestWithParam, - bool>> {}; - -TEST_P(RMSNormTestSuite, TestRMSNorm) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - const bool zero_centered_gamma = std::get<3>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma););); -} - -INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite, - ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, - DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, - DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo &info) { - std::string name = - test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param)); - return name; - }); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 759c1c19ae..84fc567cd3 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -64,13 +64,14 @@ list(APPEND transformer_engine_SOURCES fused_attn/thd_utils.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu - layer_norm/ln_api.cpp - layer_norm/ln_bwd_semi_cuda_kernel.cu - layer_norm/ln_fwd_cuda_kernel.cu + normalization/common.cpp + 
normalization/layernorm/ln_api.cpp + normalization/layernorm/ln_bwd_semi_cuda_kernel.cu + normalization/layernorm/ln_fwd_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_api.cpp + normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu - rmsnorm/rmsnorm_api.cpp - rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - rmsnorm/rmsnorm_fwd_cuda_kernel.cu util/cast.cu util/padding.cu util/cuda_driver.cpp diff --git a/transformer_engine/common/include/transformer_engine/layer_norm.h b/transformer_engine/common/include/transformer_engine/layer_norm.h deleted file mode 100644 index 3bb4d47f29..0000000000 --- a/transformer_engine/common/include/transformer_engine/layer_norm.h +++ /dev/null @@ -1,159 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file layer_norm.h - * \brief LayerNorm functions. - */ - -#ifndef TRANSFORMER_ENGINE_LAYER_NORM_H_ -#define TRANSFORMER_ENGINE_LAYER_NORM_H_ - -#include "transformer_engine.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief Compute LayerNorm on the input. - * - * The formula used: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta - * @f] - * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. - * - * \param[in] x Input tensor of shape [N, H]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[in] beta Beta tensor of shape [H]. - * \param[in] epsilon Value added to denominator for numerical stability. - * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] mu Mean of the input calculated over the last dimension. 
- * Shape: [N]. - * \param[out] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier); - -/*! \brief Compute LayerNorm with zero-centered gamma on the input. - * - * The formula used: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta - * @f] - * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. - * - * \param[in] x Input tensor of shape [N, H]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[in] beta Beta tensor of shape [H]. - * \param[in] epsilon Value added to denominator for numerical stability. - * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[out] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier); - -/*! \brief Compute backward of LayerNorm. 
- * - * This function computes the gradient of function: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta - * @f] - * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. - * - * Calling this function with workspace, barrier, dgamma_part and dbeta_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[in] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dbeta Gradient for beta tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[out] dbeta_part Storage for partial bias gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, - NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); - -/*! \brief Compute backward of LayerNorm with zero-centered gamma. - * - * This function computes the gradient of function: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta - * @f] - * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. 
- * - * Calling this function with workspace, barrier, dgamma_part and dbeta_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[in] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dbeta Gradient for beta tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[out] dbeta_part Storage for partial bias gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! 
- const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, - NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TRANSFORMER_ENGINE_LAYER_NORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/rmsnorm.h b/transformer_engine/common/include/transformer_engine/normalization.h similarity index 55% rename from transformer_engine/common/include/transformer_engine/rmsnorm.h rename to transformer_engine/common/include/transformer_engine/normalization.h index dc995e3c24..de9644792b 100644 --- a/transformer_engine/common/include/transformer_engine/rmsnorm.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -4,12 +4,12 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file rmsnorm.h - * \brief RMSNorm functions. +/*! \file normalization.h + * \brief LayerNorm and RMSNorm functions. */ -#ifndef TRANSFORMER_ENGINE_RMSNORM_H_ -#define TRANSFORMER_ENGINE_RMSNORM_H_ +#ifndef TRANSFORMER_ENGINE_NORMALIZATION_H_ +#define TRANSFORMER_ENGINE_NORMALIZATION_H_ #include "transformer_engine.h" @@ -17,41 +17,73 @@ extern "C" { #endif -/*! \brief Compute RMSNorm on the input. +/*! \brief Compute LayerNorm on the input. * * The formula used: * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}\gamma - * @f] - * where - * @f[ - * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta * @f] * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. 
+ * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. * \param[in] gamma Gamma tensor of shape [H]. + * \param[in] beta Beta tensor of shape [H]. * \param[in] epsilon Value added to denominator for numerical stability. * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] rsigma Reciprocal of the root mean square of the input - * calculated over the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. + * \param[out] mu Mean of the input calculated over the last dimension. + * Shape: [N]. + * \param[out] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[out] workspace Workspace tensor. * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); + +/*! \brief Compute backward of LayerNorm. + * + * This function computes the gradient of function: + * @f[ + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta + * @f] + * else + * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. + * + * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of these tensors to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] mu Mean of the input calculated over the last dimension. 
+ * Shape: [N]. + * \param[in] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] dbeta Gradient for beta tensor of shape [H]. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ -void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, - NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); +void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu, + const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, + NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream); -/*! \brief Compute RMSNorm with zero-centered gamma on the input. +/*! \brief Compute RMSNorm. * * The formula used: * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma) + * y = \frac{x}{RMS_\varepsilon(x)}\gamma * @f] * where * @f[ @@ -68,14 +100,14 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep * \param[in,out] z Output tensor of shape [N, H]. * \param[out] rsigma Reciprocal of the root mean square of the input * calculated over the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. 
+ * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ -void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, - NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); +void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, + NVTETensor rsigma, NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); /*! \brief Compute backward of RMSNorm. * @@ -100,53 +132,25 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float * \param[in] gamma Gamma tensor of shape [H]. * \param[out] dx Output gradient of shape [N, H]. * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); -/*! \brief Compute backward of RMSNorm with zero-centered gamma. +/*! 
\brief Helper to enable cuDNN backend for normalization * - * This function computes the gradient of function: - * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma) - * @f] - * where - * @f[ - * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} - * @f] - * with respect to \f$x\f$ and \f$gamma\f$. - * - * Calling this function with workspace, barrier, dgamma_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] rsigma Reciprocal of the root mean square of the input - * calculated over the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. 
+ * \param[in] bool Enable if True */ -void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, - const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); +void nvte_enable_cudnn_norm_fwd(bool enable); +void nvte_enable_cudnn_norm_bwd(bool enable); #ifdef __cplusplus } // extern "C" #endif -#endif // TRANSFORMER_ENGINE_RMSNORM_H_ +#endif // TRANSFORMER_ENGINE_NORMALIZATION_H_ diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h deleted file mode 100644 index 13543a10aa..0000000000 --- a/transformer_engine/common/layer_norm/ln.h +++ /dev/null @@ -1,239 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" - -namespace transformer_engine { -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams { - size_t workspace_bytes; - size_t barrier_size; - - int multiprocessorCount; - cudaStream_t stream; - - Params params; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0), - rows(0), - cols(0), - x(nullptr), - mu(nullptr), - rs(nullptr), - gamma(nullptr), - workspace(nullptr), - barrier(nullptr), - zero_centered_gamma(false) {} - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. 
- int ctas_per_col; - // Size of CTA group. - int ctas_per_row; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mu; - void *rs; - void *gamma; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - - // Whether gamma is centered around 0 - bool zero_centered_gamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} - - // Output of LN FWD. - void *z; - void *beta; - float epsilon; - - // Scaling factor - void *scale; - - // AMax output - void *amax; - - // Inverse of scaling factor - void *scale_inv; - - // Whether to compute scale and amax - bool fp8_out; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase(), - dz(nullptr), - dbeta_part(nullptr), - dgamma_part(nullptr), - dx(nullptr), - dbeta(nullptr), - dgamma(nullptr) {} - - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. 
- void *dbeta; - void *dgamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId {}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 0; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 1; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 2; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 3; -}; - -template -struct Type2Key { - constexpr static uint32_t Value = TypeId::Value << S; -}; - -template -struct WeightType2Key : public Type2Key {}; - -template -struct InputType2Key : public Type2Key {}; - -template -struct OutputType2Key : public Type2Key {}; - -template -struct ComputeType2Key : public Type2Key {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key { - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | - OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size) { - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit 
FwdTunedRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp deleted file mode 100644 index 8a40450e59..0000000000 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ /dev/null @@ -1,457 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include - -#include -#include - -#include "../common.h" -#include "ln.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 -bf16 fp32 bf16 fp8 - -Remarks: -Output type = Weight type -Compute always in FP32 - -*/ - -namespace transformer_engine { -namespace layer_norm { - -using namespace transformer_engine; - -// Create registries and provide runtime versions of config hash functions. - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint32_t get_type_id(DType dtype) { - if (dtype == DType::kFloat16) { - return TypeId::Value; - } else if (dtype == DType::kBFloat16) { - return TypeId::Value; - } else if (dtype == DType::kFloat32) { - return TypeId::Value; - } else if (dtype == DType::kFloat8E4M3) { - return TypeId::Value; - } else { - NVTE_ERROR("Type not supported."); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | - (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, 
itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) && - is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) && - is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) && - is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) && - layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, 
itype, otype, ctype, 0); - if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t product(const std::vector& shape) { - size_t ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; -} - -} // namespace layer_norm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void layernorm_fwd(const Tensor& x, // BxSxhidden_size - const Tensor& gamma, // hidden_size - const Tensor& beta, // hidden_size - const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream, - const int multiprocessorCount, Tensor* workspace, Tensor* barrier, - const bool zero_centered_gamma) { - const auto itype = x.data.dtype; - const auto wtype = gamma.data.dtype; - const auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - const auto ctype = layer_norm::DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(hidden_size == cols); - - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(mu->data.shape == std::vector{rows}); - NVTE_CHECK(mu->data.dtype == ctype); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - layer_norm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = 
stream; - - // Set the kernel runtime parameters. - layer_norm::FwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu->data.dptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = beta.data.dptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - CheckInputTensor(beta, "beta"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*mu, "mu"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear 
buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); - - return; -} - -void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, - const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, - Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream, - const int multiprocessorCount, Tensor* workspace, Tensor* barrier, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(mu.data.dtype == ctype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - auto rows = x.data.shape[0]; - auto cols = x.data.shape[1]; - - auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(mu.data.shape[0] == rows); - NVTE_CHECK(mu.data.shape == rsigma.data.shape); - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - NVTE_CHECK(dbeta->data.shape == gamma.data.shape); - NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); - - layer_norm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. 
- layer_norm::BwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu.data.dptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = dbeta->data.dptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = dbeta_part->data.dptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - NVTE_CHECK(dbeta_part->data.dptr == nullptr); - - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - dbeta_part->data.dtype = ctype; - dbeta_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(dbeta_part->data.dptr != nullptr); - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - NVTE_CHECK(dbeta_part->data.dtype == ctype); - NVTE_CHECK(dbeta_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - 
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(mu, "mu"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - CheckOutputTensor(*dbeta, "dbeta"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); -} -} // namespace transformer_engine - -void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size - const NVTETensor gamma, // hidden_size - const NVTETensor beta, // hidden_size - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm_fwd); - using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! 
- const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, - NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm_bwd); - using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), - stream, multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size - const NVTETensor gamma, // hidden_size - const NVTETensor beta, // hidden_size - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm1p_fwd); - using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} - -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! 
- const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, - NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm1p_bwd); - using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), - stream, multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} diff --git a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 17f1256910..0000000000 --- a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,345 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "ln.h" -#include "ln_bwd_kernels.cuh" -#include "ln_kernel_traits.h" - -using namespace transformer_engine::layer_norm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel; - kernel_f<<>>( - 
launch_params.params); -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } - - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - 
&ln_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, 
fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 
2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp32, fp16, 
fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, 
bf16, bf16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -// Create general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu 
b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu deleted file mode 100644 index 0c85f4aeb7..0000000000 --- a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu +++ /dev/null @@ -1,413 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "ln.h" -#include "ln_fwd_kernels.cuh" -#include "ln_kernel_traits.h" - -using namespace transformer_engine::layer_norm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 
grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); - } -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create tuned launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, fp8e4m3, 
fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 
1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp16, fp32, 2, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - 
-REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16); - -// Create general launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8); - -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 
16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp new file mode 100644 index 0000000000..5b6beb66b1 --- /dev/null +++ b/transformer_engine/common/normalization/common.cpp @@ -0,0 +1,445 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +/* #include */ + +#include "common.h" + +#include +#include +#include +#include +#include + +#include "transformer_engine/normalization.h" + +/* + +Supported Type combinations: + +input compute weights output +======================================= +fp32 fp32 fp32 fp32 +fp16 fp32 fp16 fp16 +bf16 fp32 bf16 bf16 +fp32 fp32 fp16 fp16 +fp32 fp32 bf16 bf16 +bf16 fp32 bf16 fp8 + +Remarks: +Output type = Weight type +Compute always in FP32 + +*/ + +namespace transformer_engine { +namespace normalization { + +TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, + bool zero_centered_gamma, bool is_tuned) { + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | + (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | + (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | + (uint32_t(zero_centered_gamma) << 16); + return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); +} + +template +TeNormalizationPlan::TeNormalizationPlan( + NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, + DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, + const bool zero_centered_gamma, const bool is_tuned) + : _is_layernorm(NormType == NVTE_Norm_Type::LayerNorm) { + _launch_params.multiprocessorCount = sm_count; + + auto& kernel_params = _launch_params.params; + kernel_params.rows = batch_size; + kernel_params.cols = hidden_size; + kernel_params.zero_centered_gamma = zero_centered_gamma; + if constexpr (std::is_same_v) { + kernel_params.fp8_out = is_fp8_dtype(otype); + } + // TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those + auto key = + get_key(NormType, NormStage, wtype, itype, otype, ctype, 0, hidden_size, false, is_tuned); + _kernel = 
KernelRegistry::getKernel(key); + + this->_build(); +} + +template <> +void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, + void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + _launch_params.stream = stream; + + auto& kernel_params = _launch_params.params; + kernel_params.workspace = workspace_dptr; + kernel_params.x = x_dptr; + kernel_params.rs = rsigma_dptr; + kernel_params.gamma = gamma_dptr; + kernel_params.z = z->data.dptr; + kernel_params.epsilon = *reinterpret_cast(eps_dptr); + kernel_params.amax = z->amax.dptr; + kernel_params.scale = z->scale.dptr; + kernel_params.scale_inv = z->scale_inv.dptr; + + if (_is_layernorm) { + kernel_params.mu = mean_dptr; + kernel_params.beta = beta_dptr; + } + + _set_workspace(); + _kernel(_launch_params, false); +} + +template <> +void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, + void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + NVTE_ERROR("Backward normalization should not call the forward execute function!"); +} + +template +void TeNormalizationPlan::_build() { + _kernel(_launch_params, true); + _launch_params.alignWorkspace(); +} + +template +std::vector TeNormalizationPlan::getWorkspaceShape() const { + return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)}; +} + +template +void TeNormalizationPlan::_set_workspace() { + if (_launch_params.getTotalWorkspaceBytes() > 0) { + auto workspace_dptr = reinterpret_cast(_launch_params.params.workspace); + + if (_launch_params.barrier_bytes > 0) { + _launch_params.params.barrier = + reinterpret_cast(workspace_dptr + _launch_params.workspace_bytes); + cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes, + _launch_params.stream); + } + if constexpr (std::is_same_v) { + _launch_params.params.dgamma_part = + workspace_dptr + _launch_params.workspace_bytes + 
_launch_params.barrier_bytes; + if (_is_layernorm) { + _launch_params.params.dbeta_part = + reinterpret_cast(_launch_params.params.dgamma_part) + + _launch_params.dgamma_part_bytes; + } + } + } +} + +template <> +void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, + void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + NVTE_ERROR("Forward normalization should not call the backward execute function!"); +} + +template <> +void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, + void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + _launch_params.stream = stream; + + auto& kernel_params = _launch_params.params; + kernel_params.workspace = workspace_dptr; + kernel_params.x = x_dptr; + kernel_params.gamma = gamma_dptr; + kernel_params.rs = rsigma_dptr; + kernel_params.dx = dx_dptr; + kernel_params.dz = dz_dptr; + kernel_params.dgamma = dgamma_dptr; + + if (_is_layernorm) { + kernel_params.mu = mean_dptr; + kernel_params.dbeta = dbeta_dptr; + } + + _set_workspace(); + _kernel(_launch_params, false); +} + +CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, DType ctype, + const size_t batch_size, const size_t hidden_size, + const size_t sm_count, + const bool zero_centered_gamma) + : _fp8_out(is_fp8_dtype(otype)), _zero_centered(zero_centered_gamma) { + static_assert(CUDNN_FRONTEND_VERSION >= 10601, + "CUDNN_FRONTEND_VERSION should be at least 1.6.1!"); + + namespace fe = cudnn_frontend; + + _scalar_dptr = std::make_unique(typeToSize(wtype)); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); + + _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + + 
_graph.set_io_data_type(get_cudnn_fe_dtype(itype)) + .set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + + if (cudnnGetVersion() >= 90400) _graph.set_sm_count(sm_count); + + const auto batch_dim = static_cast(batch_size); + const auto hidden_dim = static_cast(hidden_size); + + // Create graph tensors + _x = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(itype))); + + _gamma_zero = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("gamma_zero") + .set_dim({1, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + if (zero_centered_gamma) { + _scalar_offset = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("one") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(wtype)) + .set_is_pass_by_value(true)); + auto centered_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options); + _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype)); + } else { + _gamma = _gamma_zero; + } + + // Create graph computation nodes + if (NormStage == NVTE_Norm_Stage::Forward) { + _eps = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("epsilon") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype)) + .set_is_pass_by_value(true)); + if (NormType == NVTE_Norm_Type::LayerNorm) { + _beta = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({1, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + auto norm_options = fe::graph::Layernorm_attributes() + 
.set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_epsilon(_eps) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options); + std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]); + _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else if (NormType == NVTE_Norm_Type::RMSNorm) { + auto norm_options = fe::graph::Rmsnorm_attributes() + .set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_epsilon(_eps) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.rmsnorm(_x, _gamma, norm_options); + std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]); + } + + _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + + const auto ZDtype = _fp8_out ? ctype : otype; + _z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype)); + + if (_fp8_out) { + // create a scale node + _z_scale = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("z_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + auto z_scale_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); + + _z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + + // create an amax reduction node + _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() + .set_mode(fe::ReductionMode_t::AMAX) + .set_compute_data_type(get_cudnn_fe_dtype(ctype))); + _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + } + } else { + _dz = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("dz") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})); + _rsigma = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("inv_var") + .set_dim({batch_dim, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + 
.set_data_type(get_cudnn_fe_dtype(ctype))); + _mean = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("mean") + .set_dim({batch_dim, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + if (NormType == NVTE_Norm_Type::LayerNorm) { + auto norm_options = fe::graph::Layernorm_backward_attributes() + .set_saved_mean_and_inv_variance(_mean, _rsigma) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.layernorm_backward(_dz, _x, _gamma, norm_options); + std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); + _dbeta->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + } else { + auto norm_options = + fe::graph::Rmsnorm_backward_attributes().has_dbias(false).set_compute_data_type( + get_cudnn_fe_dtype(ctype)); + auto ret = _graph.rmsnorm_backward(_dz, _x, _gamma, _rsigma, norm_options); + std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); + if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned."); + } + _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + } + // Build the graph + this->_build(); +} + +void CudnnNormalizationPlan::_build() { + NVTE_CHECK(_graph.validate().is_good()); + NVTE_CHECK(_graph.build_operation_graph(_handle).is_good()); + NVTE_CHECK(_graph + .create_execution_plans( + {cudnn_frontend::HeurMode_t::A, cudnn_frontend::HeurMode_t::FALLBACK}) + .is_good()); + NVTE_CHECK(_graph.check_support(_handle).is_good()); + NVTE_CHECK( + _graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); +} + +std::vector CudnnNormalizationPlan::getWorkspaceShape() const { + return {static_cast(_graph.get_workspace_size())}; +} + +void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, + void* mean_dptr, void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) 
{ + // Binding data pointers to graph tensors + _variant_pack = {{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_eps, eps_dptr}}; + + // layernorm should have valid mean_dptr and beta_dptr + if (mean_dptr && beta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_beta, beta_dptr}}); + + if (_zero_centered) + _variant_pack.insert( + {{_scalar_offset, reinterpret_cast(_scalar_dptr.get())}, {_gamma_zero, gamma_dptr}}); + else + _variant_pack.insert({{_gamma, gamma_dptr}}); + + if (_fp8_out) + _variant_pack.insert( + {{_z_scale, z->scale.dptr}, {_amax, z->amax.dptr}, {_z_fp8, z->data.dptr}}); + else + _variant_pack.insert({{_z, z->data.dptr}}); + + // Execute the computation + NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); + NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); + if (_fp8_out) update_tensor_scale_inv(z, stream); +} + +void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, + void* rsigma_dptr, void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) { + // Binding data pointers to graph tensors + _variant_pack = { + {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; + + if (_zero_centered) + _variant_pack.insert({{_scalar_offset, reinterpret_cast(this->_scalar_dptr.get())}, + {_gamma_zero, gamma_dptr}}); + else + _variant_pack.insert({{_gamma, gamma_dptr}}); + + // layernorm should have valid mean_dptr and beta_dptr + if (mean_dptr && dbeta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_dbeta, dbeta_dptr}}); + + // Execute the computation + NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); + NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); +} + +NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( + NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, + DType itype, DType otype, const size_t batch_size, const 
size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned) { + const DType ctype = DType::kFloat32; + bool is_tuned = is_aligned && (batch_size % 4 == 0); + auto key = get_key(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, + zero_centered_gamma, is_tuned); + + auto it = normalizationPlanMap.find(key); + if (it != normalizationPlanMap.end()) { + return it->second.get(); + } + + std::unique_ptr plan; + if (NormBackend == NVTE_Norm_Backend::Cudnn) { + plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, + batch_size, hidden_size, sm_count, + zero_centered_gamma); + } else if (NormStage == NVTE_Norm_Stage::Forward) { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } else { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } + normalizationPlanMap.insert({key, std::move(plan)}); + return normalizationPlanMap[key].get(); +} + +bool& _cudnn_norm_fwd_flag() { + static bool flag = transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN"); + return flag; +} + +bool& _cudnn_norm_bwd_flag() { + static bool flag = transformer_engine::getenv("NVTE_NORM_BWD_USE_CUDNN"); + return flag; +} + +bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); } +bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } + +} // namespace normalization +} // namespace transformer_engine + +void nvte_enable_cudnn_norm_fwd(bool enable) { + NVTE_API_CALL(nvte_enable_cudnn_norm_fwd); + transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable; +} + +void nvte_enable_cudnn_norm_bwd(bool enable) { + NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); + transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; +} diff --git a/transformer_engine/common/normalization/common.h 
b/transformer_engine/common/normalization/common.h new file mode 100644 index 0000000000..8a8df63ba4 --- /dev/null +++ b/transformer_engine/common/normalization/common.h @@ -0,0 +1,382 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ +#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../cudnn_utils.h" +#include "../util/system.h" + +namespace transformer_engine { + +namespace normalization { + +namespace fe = cudnn_frontend; + +template +struct LaunchParams { + size_t workspace_bytes = 0; + size_t barrier_bytes = 0; + size_t dgamma_part_bytes = 0; + int multiprocessorCount; + cudaStream_t stream; + + KernelParamsType params; + + size_t getTotalWorkspaceBytes(const bool _is_layernorm = true) const { + return (workspace_bytes + barrier_bytes + size_t(_is_layernorm + 1) * dgamma_part_bytes); + } + void alignWorkspace(size_t alignment = 16) { + workspace_bytes = DIVUP(workspace_bytes, alignment) * alignment; + barrier_bytes = DIVUP(barrier_bytes, alignment) * alignment; + dgamma_part_bytes = DIVUP(dgamma_part_bytes, alignment) * alignment; + } +}; + +struct KernelParamsBase { + KernelParamsBase() + : ctas_per_col(0), + rows(0), + cols(0), + x(nullptr), + mu(nullptr), + rs(nullptr), + gamma(nullptr), + workspace(nullptr), + barrier(nullptr), + zero_centered_gamma(false) {} + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + // Size of CTA group. + int ctas_per_row; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. 
+ void* x; + void* mu; + void* rs; + void* gamma; + + // Multi-CTA workspace in gmem. + void* workspace; + + // Multi-CTA sync barriers in gmem. + int* barrier; + + // Whether gamma is centered around 0 + bool zero_centered_gamma; +}; + +struct ForwardKernelParams : public KernelParamsBase { + ForwardKernelParams() + : KernelParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} + + // Output of LN FWD. + void* z; + void* beta; + float epsilon; + + // Scaling factor + void* scale; + int scale_byte_size; + + // Inverse of scaling factor + void* scale_inv; + + // AMax output + void* amax; + int amax_byte_size; + + // Whether to compute scale and amax + bool fp8_out; +}; + +struct BackwardKernelParams : public KernelParamsBase { + BackwardKernelParams() + : KernelParamsBase(), + dz(nullptr), + dbeta_part(nullptr), + dgamma_part(nullptr), + dx(nullptr), + dbeta(nullptr), + dgamma(nullptr) {} + + // Input: gradient wrt. LN FWD output. + void* dz; + + // Workspace for Wgrad pre-reduction. + void* dbeta_part; + void* dgamma_part; + + // Output: Dgrad. + void* dx; + // Output: Wgrad. 
+ void* dbeta; + void* dgamma; +}; + +enum class NVTE_Norm_Backend { Te, Cudnn }; +enum class NVTE_Norm_Type { LayerNorm, RMSNorm }; +enum class NVTE_Norm_Stage { Forward, Backward }; + +using TupleKeyType = std::tuple; +struct TupleHash { + size_t operator()(const TupleKeyType& t) const { + // Generate a hash for a tuple by combining the hashes of its entries + // See: https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine + size_t seed = 0; + std::hash hasher; + seed ^= hasher(std::get<0>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(std::get<1>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(std::get<2>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } +}; + +TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, + bool zero_centered_gamma, bool is_tuned); + +template +class TeNormalizationRegistry { + private: + using Function = std::function&, const bool)>; + std::unordered_map tuned_function_map; + std::unordered_map> general_function_map; + + TeNormalizationRegistry() = default; + + static TeNormalizationRegistry& getInstance() { + static TeNormalizationRegistry registry; + return registry; + } + + public: + static int registerFunction(TupleKeyType key, + void (*func)(LaunchParams&, const bool)) { + auto [general_key, batch_size, hidden_size, is_tuned] = key; + if (is_tuned) + getInstance().tuned_function_map.emplace(key, Function(func)); + else + getInstance().general_function_map[general_key].emplace(hidden_size, Function(func)); + return 0; + } + + static Function getKernel(TupleKeyType key) { + auto& instance = getInstance(); + auto [general_key, batch_size, hidden_size, is_tuned] = key; + if (is_tuned) { + auto it = instance.tuned_function_map.find(key); + if (it != instance.tuned_function_map.end()) return it->second; + } + if 
(instance.general_function_map.count(general_key) == 0) { + NVTE_ERROR("Unavailable kernel for this normalization config."); + } + auto& general_func_map = instance.general_function_map.at(general_key); + auto func_iter = general_func_map.lower_bound(hidden_size); + if (func_iter == general_func_map.end()) { + return general_func_map.rbegin()->second; // Hidden size is too big, need to use multi-CTA + } else { + return func_iter->second; + } + } + + TeNormalizationRegistry(const TeNormalizationRegistry&) = delete; + TeNormalizationRegistry& operator=(const TeNormalizationRegistry&) = delete; + TeNormalizationRegistry(TeNormalizationRegistry&&) = delete; + TeNormalizationRegistry& operator=(TeNormalizationRegistry&&) = delete; +}; + +class NormalizationPlanBase { + public: + virtual ~NormalizationPlanBase() = default; + virtual std::vector getWorkspaceShape() const = 0; + + virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) = 0; + + virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) = 0; + + private: + virtual void _build() = 0; +}; + +template +class TeNormalizationPlan : public NormalizationPlanBase { + public: + TeNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_tuned); + std::vector getWorkspaceShape() const override; + + void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* 
dx_dptr, + void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + private: + void _set_workspace(); + void _build(); + + using KernelRegistry = TeNormalizationRegistry; + LaunchParams _launch_params; + std::function&, const bool)> _kernel; + + const bool _is_layernorm; +}; + +class CudnnNormalizationPlan : public NormalizationPlanBase { + public: + CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, + DType itype, DType otype, DType ctype, const size_t batch_size, + const size_t hidden_size, const size_t sm_count, + const bool zero_centered_gamma); + + std::vector getWorkspaceShape() const override; + + void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, + void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + private: + void _build() override; + + const bool _zero_centered, _fp8_out; + std::unique_ptr _scalar_dptr; + // FWD + std::shared_ptr _x, _gamma_zero, _scalar_offset, _gamma, _beta, + _eps, _mean, _rsigma, _z, _z_scale, _amax, _z_fp8; + // BWD + std::shared_ptr _dz, _dx, _dgamma, _dbeta; + + fe::graph::Graph _graph; + std::unordered_map, void*> _variant_pack; + cudnnHandle_t _handle; +}; + +class NormalizationPlanRegistry { + public: + // TODO thread-safe + static NormalizationPlanRegistry& getInstance() { + static NormalizationPlanRegistry instance; + return instance; + } + + NormalizationPlanBase* getNormalizationPlan(NVTE_Norm_Backend NormBackend, + NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, + const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, + const bool is_aligned); + + 
private: + NormalizationPlanRegistry() {} + NormalizationPlanRegistry(const NormalizationPlanRegistry&) = delete; + NormalizationPlanRegistry& operator=(const NormalizationPlanRegistry&) = delete; + + std::unordered_map, TupleHash> + normalizationPlanMap; +}; + +using byte = uint8_t; +using int32 = int32_t; +using fp32 = float; +using fp16 = half; +using bf16 = nv_bfloat16; +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; + +template +struct TypeToDType; + +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat32; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat16; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kBFloat16; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat8E4M3; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat8E5M2; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kInt32; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kByte; +}; + +#define IS_TUNED(x) (strcmp(#x, "tuned") == 0 ? 1 : 0) + +// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those +#define REGISTER_NORM_BASE(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, \ + CTYPE, FUNC_NAME) \ + static int \ + register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \ + TeNormalizationRegistry::registerFunction( \ + (get_key(NVTE_Norm_Type::NORM_TYPE, NVTE_Norm_Stage::NORM_STAGE, \ + (TypeToDType::value), (TypeToDType::value), \ + (TypeToDType::value), (TypeToDType::value), 0, HIDDEN_SIZE, \ + 0, IS_TUNED(LAUNCH_TYPE))), \ + FUNC_NAME) + +// For FP8 only +void ComputeScaleInv(void* scale, void* scale_inv); + +// Alignment check +template +bool is_ptr_aligned(const Args*... 
ptrs) { + return ((reinterpret_cast(ptrs) % Alignment == 0) && ...); +} + +bool use_cudnn_norm_fwd(); +bool use_cudnn_norm_bwd(); + +} // namespace normalization + +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/layer_norm/ln_kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h similarity index 89% rename from transformer_engine/common/layer_norm/ln_kernel_traits.h rename to transformer_engine/common/normalization/kernel_traits.h index a72726c325..0f8fea3f0b 100644 --- a/transformer_engine/common/layer_norm/ln_kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -4,16 +4,15 @@ * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_ +#ifndef TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_ +#define TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_ #include "../common.h" #include "../utils.cuh" -//////////////////////////////////////////////////////////////////////////////////////////////////// - namespace transformer_engine { -namespace layer_norm { +namespace normalization { + template struct Kernel_traits_base { @@ -28,8 +27,6 @@ struct Kernel_traits_base { enum { THREADS_PER_WARP = 32 }; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - template + +#include +#include +#include +#include +#include + +#include "../../common.h" +#include "../common.h" + +namespace transformer_engine { + +using namespace normalization; + +void layernorm_fwd(const Tensor& x, // BxSxhidden_size + const Tensor& gamma, // hidden_size + const Tensor& beta, // hidden_size + const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + 
NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(gamma.data.shape == beta.data.shape); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); + + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(mu->data.dtype == DType::kFloat32); + + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + CheckInputTensor(beta, "beta"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*mu, "mu"); + CheckOutputTensor(*rsigma, "rsigma"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_fwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, + mu->data.dptr, rsigma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + z->data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + } + return; +} + +void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, + const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, + Tensor* 
workspace, const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + using namespace transformer_engine; + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(mu.data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma.data.dtype == mu.data.dtype); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + NVTE_CHECK(mu.data.shape[0] == x.data.shape[0]); + NVTE_CHECK(mu.data.shape == rsigma.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + NVTE_CHECK(dbeta->data.shape == gamma.data.shape); + NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(mu, "mu"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + CheckOutputTensor(*dbeta, "dbeta"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, + dx->data.dptr, dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Backward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + 
workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream); + } + return; +} +} // namespace transformer_engine + +void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size + const NVTETensor gamma, // hidden_size + const NVTETensor beta, // hidden_size + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + NVTE_API_CALL(nvte_layernorm_fwd); + using namespace transformer_engine; + layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), + *reinterpret_cast(beta), epsilon, reinterpret_cast(z), + reinterpret_cast(mu), reinterpret_cast(rsigma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} + +void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size + const NVTETensor x, // BxSxhidden_size + const NVTETensor mu, // BxS, FP32! + const NVTETensor rsigma, // BxS, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_layernorm_bwd); + using namespace transformer_engine; + layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(mu), *reinterpret_cast(rsigma), + *reinterpret_cast(gamma), reinterpret_cast(dx), + reinterpret_cast(dgamma), reinterpret_cast(dbeta), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} diff --git a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh similarity index 97% rename from transformer_engine/common/layer_norm/ln_bwd_kernels.cuh rename to transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index dbd0025244..44078a040b 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -7,16 +7,15 @@ #ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ -#include "../utils.cuh" -#include "ln.h" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace layer_norm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; @@ -119,8 +118,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( } reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * rn; - mdyy_local = layer_norm::Get<1>::of(result) * rn; + mdy_local = 
Get<0>::of(result) * rn; + mdyy_local = Get<1>::of(result) * rn; Ivec dx[LDGS]; idx = row * Ktraits::VEC_COLS + c; @@ -203,7 +202,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; @@ -323,7 +322,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -424,8 +423,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne // Reduce over row reduce_t result = reducer.allreduce({mdy, mdyy}, sum); - mdy = layer_norm::Get<0>::of(result) * rn; - mdyy = layer_norm::Get<1>::of(result) * rn; + mdy = Get<0>::of(result) * rn; + mdyy = Get<1>::of(result) * rn; // Compute dx #pragma unroll @@ -507,7 +506,7 @@ template __global__ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; @@ -573,7 +572,7 @@ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_gener } } -} // namespace layer_norm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu new file mode 
100644 index 0000000000..d6e15dfc30 --- /dev/null +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -0,0 +1,331 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../../common.h" +#include "../common.h" +#include "../kernel_traits.h" +#include "ln_bwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + 
launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(&params_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = &ln_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + 
void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &ln_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +/* Entries below select launch_tuned_ for the LayerNorm backward pass. After the hidden size come the four type names (WTYPE, ITYPE, OTYPE, CTYPE) and then five integers in the order CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL. */ +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048,
bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, bf16, fp32, bf16,
fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32,
2, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, bf16, fp32, bf16,
fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, bf16, fp32,
bf16, fp32, 5, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16,
fp32, bf16, fp32, 8, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); + +// Create general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... +// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +/* Entries below select launch_general_ for the LayerNorm backward pass; their four trailing integers are WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL, with no CTAS_PER_ROW argument. */ +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
+REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu new file mode 100644 index 0000000000..e7fe7a201b --- /dev/null +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -0,0 +1,395 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information.
+ ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "ln_fwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_tuned_kernel; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; + } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_general_kernel; + 
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) 
\ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, 
bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 
4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, bf16, fp32, 1, 4, 1, 
16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 
16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, 
fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, 
fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, 
Forward, tuned, 30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16); + +// Create general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh similarity index 97% rename from transformer_engine/common/layer_norm/ln_fwd_kernels.cuh rename to transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index bd3741d1d1..3ec5543c3a 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -10,15 +10,16 @@ #include #include -#include "../utils.cuh" -#include "ln.h" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace layer_norm { +namespace normalization { using namespace transformer_engine; template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) { +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( + ForwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -92,8 +93,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( stats_t s = stats.compute(xf, rn); - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); + compute_t mu = Get<0>::of(s); + compute_t m2 = Get<1>::of(s); if (bidn == 0 && warp_n == 0 && lane == 0) { mu_ptr[row] = mu; @@ -150,7 +151,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -315,7 +316,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } } -} // namespace 
layer_norm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp new file mode 100644 index 0000000000..f6e36ae3c9 --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -0,0 +1,166 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include "../../common.h" +#include "../common.h" +#include "transformer_engine/normalization.h" + +namespace transformer_engine { + +using namespace normalization; + +void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, + Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + NVTE_CHECK(x.data.shape.size() == 2); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*rsigma, "rsigma"); + } + + Tensor empty; + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_fwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); + } + auto plan = 
NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + z->data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr, nullptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + } + + return; +} + +void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, + Tensor *dx, Tensor *dgamma, Tensor *workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + using namespace transformer_engine; + + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(rsigma.data.dtype == DType::kFloat32); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + } + + Tensor empty; + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(x.data.dptr, 
gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Backward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, nullptr, dgamma->data.dptr, workspace->data.dptr, stream); + } + return; +} + +} // namespace transformer_engine + +void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size + const NVTETensor gamma, // hidden_size + const float epsilon, NVTETensor z, NVTETensor rsigma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_fwd); + using namespace transformer_engine; + rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), + epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} + +void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size + const NVTETensor x, // Nxhidden_size + const NVTETensor rsigma, // N, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_bwd); + using namespace transformer_engine; + rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(rsigma), *reinterpret_cast(gamma), + reinterpret_cast(dx), reinterpret_cast(dgamma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh similarity index 97% rename from transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh rename to transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index 92fd850baa..223ac7fd79 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -7,15 +7,15 @@ #ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ -#include "../utils.cuh" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace rmsnorm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; @@ -172,7 +172,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; 
using index_t = typename Kernel_traits::index_t; @@ -276,7 +276,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel( - BwdParams params) { + BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -430,8 +430,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ template -__global__ __launch_bounds__( - WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) { +__global__ +__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel( + BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; @@ -474,7 +475,7 @@ __global__ __launch_bounds__( } } -} // namespace rmsnorm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu new file mode 100644 index 0000000000..309075c1ec --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -0,0 +1,206 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "rmsnorm_bwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = 
&rmsnorm_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + 
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &rmsnorm_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create rmsnorm tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +// Create rmsnorm general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + 
+REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu new file mode 100644 index 0000000000..73634fc2dd --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -0,0 +1,210 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "rmsnorm_fwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_fwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; + } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel =
&rmsnorm_fwd_general_kernel; + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(&params_), 0, stream); + } +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...)
\ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create rmsnorm tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); 
+REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +// Create rmsnorm general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, 
bf16, bf16, bf16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git 
a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh similarity index 98% rename from transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh rename to transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index c435ae3744..5965ffdc5d 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -10,15 +10,15 @@ #include #include -#include "../utils.cuh" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace rmsnorm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -143,7 +143,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -291,7 +291,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ } } -} // namespace rmsnorm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm.h b/transformer_engine/common/rmsnorm/rmsnorm.h deleted file mode 100644 index 8b4e1cf24e..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm.h +++ /dev/null @@ -1,89 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" -#include "../layer_norm/ln.h" - -namespace transformer_engine { -namespace rmsnorm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams : public transformer_engine::layer_norm::LaunchParams {}; -struct FwdParams : public transformer_engine::layer_norm::FwdParams {}; -struct BwdParams : public transformer_engine::layer_norm::BwdParams {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - 
-////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp deleted file mode 100644 index 9b143b2f85..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ /dev/null @@ -1,387 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include - -#include -#include -#include - -#include "../common.h" -#include "rmsnorm.h" -#include "transformer_engine/rmsnorm.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp32 fp16 -fp32 fp32 fp32 bf16 -fp32 fp32 fp32 fp8 -fp16 fp32 fp16 fp8 -bf16 fp32 bf16 fp8 - -Remarks: -Input type = Weight type -Compute always in FP32 - -*/ - -namespace transformer_engine { - -namespace layer_norm { -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size); -} - -namespace rmsnorm { - -using namespace transformer_engine; - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams &params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return
func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams &params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && - is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && - BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -// //////////////////////////////////////////////////////////////////////////////////////////////////// - -inline size_t product(const std::vector &shape) { - return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); -} - -} // namespace rmsnorm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, - Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount, - Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) { - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; -
auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - auto ctype = DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(hidden_size == cols); - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - rmsnorm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. - rmsnorm::FwdParams &params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = nullptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters.
- launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); - - return; -} - -void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, - Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream, - const int multiprocessorCount, Tensor *workspace, Tensor *barrier, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - - const auto rows = x.data.shape[0]; - const auto cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - rmsnorm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - rmsnorm::BwdParams &params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = nullptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = nullptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters.
- launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); -} - -} // namespace transformer_engine - -void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size - const NVTETensor gamma, // hidden_size - const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm_fwd); - using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size - const NVTETensor x, // Nxhidden_size - const NVTETensor rsigma, // N, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm_bwd); - using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), false); -} - -void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size - const NVTETensor gamma, // hidden_size - const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm1p_fwd); - using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} - -void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size - const NVTETensor x, // Nxhidden_size - const 
NVTETensor rsigma, // N, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm1p_bwd); - using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), true); -} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 3215a6a9d4..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,220 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "rmsnorm.h" -#include "rmsnorm_bwd_kernels.cuh" -#include "rmsnorm_kernel_traits.h" - -using namespace transformer_engine::rmsnorm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = - rmsnorm::Kernel_traits; - auto kernel = &rmsnorm_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(&params_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = - Kernel_traits_finalize; - - auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel; - kernel_f<<>>(
- launch_params.params); -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } - - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - 
&rmsnorm_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create rmsnorm tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -// Create rmsnorm general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git 
a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu deleted file mode 100644 index 3c8e121540..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ /dev/null @@ -1,227 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "rmsnorm.h" -#include "rmsnorm_fwd_kernels.cuh" -#include "rmsnorm_kernel_traits.h" - -using namespace transformer_engine::rmsnorm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if 
(ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); - } -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, 
stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create rmsnorm tuned launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); - -// Create rmsnorm general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, 
fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8); - -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h 
b/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h deleted file mode 100644 index 26d7da6400..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h +++ /dev/null @@ -1,42 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ - -#include "../common.h" -#include "../layer_norm/ln_kernel_traits.h" -#include "../utils.cuh" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace transformer_engine { -namespace rmsnorm { - -template < - uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_, - typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_, - typename Base = - layer_norm::Kernel_traits_finalize > -struct Kernel_traits_finalize : public Base {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template > -struct Kernel_traits : public Base {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index fd6cc09de9..0b7df0b5a8 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -16,7 +16,6 @@ from jax.extend import ffi from transformer_engine import transformer_engine_jax -from 
transformer_engine.transformer_engine_jax import DType as TEDType from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -82,7 +81,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -96,18 +95,15 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = out_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval + return out_aval, mu_aval, rsigma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm fwd outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, _, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, mu_aval, rsigma_aval @staticmethod @@ -151,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-2:] out_types = [ ir.RankedTensorType.get(out_shape, output_type), @@ -160,9 +156,6 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, beta] operand_shapes = [x_shape, g_shape, b_shape] 
@@ -174,15 +167,9 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -198,7 +185,7 @@ def impl(x, gamma, beta, zero_centered_gamma, epsilon): to describe implementation """ assert LayerNormFwdPrimitive.inner_primitive is not None - out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind( + out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind( x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) return out, mu, rsigma @@ -377,39 +364,25 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) - wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = ( - transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - True, - kwargs["zero_centered_gamma"], - kwargs["epsilon"], - get_backward_sm_margin(), - ) + (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + True, + kwargs["zero_centered_gamma"], + kwargs["epsilon"], + get_backward_sm_margin(), ) wkspace_aval = dx_aval.update( shape=wkspace_info[0], 
dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = dx_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - dgamma_part_aval = dgamma_aval.update( - shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) - ) - dbeta_part_aval = dbeta_aval.update( - shape=dbeta_part_info[0], dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]) - ) return ( dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, - barrier_aval, - dgamma_part_aval, - dbeta_part_aval, ) @staticmethod @@ -417,9 +390,7 @@ def outer_abstract(*args, **kwargs): """ LayerNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = LayerNormBwdPrimitive.abstract( - *args, **kwargs - ) + dx_aval, dgamma_aval, dbeta_aval, _ = LayerNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval, dbeta_aval @staticmethod @@ -470,20 +441,14 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): sm_margin = get_backward_sm_margin() - wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] + wkspace_aval = ctx.avals_out[-4:] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - dbeta_part_aval.shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - jax_dtype_to_te_dtype(dbeta_part_aval.dtype), zero_centered_gamma, epsilon, sm_margin, @@ -496,7 +461,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): @staticmethod def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): assert LayerNormBwdPrimitive.inner_primitive is not None - dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, dbeta, _ = LayerNormBwdPrimitive.inner_primitive.bind( dz, x, mu, 
rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) return dx, dgamma, dbeta @@ -630,7 +595,7 @@ def abstract(x_aval, gamma_aval, **kwargs): hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -644,18 +609,15 @@ def abstract(x_aval, gamma_aval, **kwargs): wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = out_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, rsigma_aval, wkspace_aval, barrier_aval + return out_aval, rsigma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm fwd outer primitive abstract """ - out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, rsigma_aval @staticmethod @@ -688,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-2:] out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), @@ -696,9 +658,6 @@ def lowering(ctx, x, gamma, *, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma] operand_shapes = [x_shape, g_shape] @@ -710,15 +669,9 @@ def lowering(ctx, x, gamma, *, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass 
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -734,7 +687,7 @@ def impl(x, gamma, epsilon): to describe implementation """ assert RmsNormFwdPrimitive.inner_primitive is not None - out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) + out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) return out, rsigma @staticmethod @@ -833,36 +786,28 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = core.raise_to_shaped(gamma_aval) - wkspace_info, barrier_info, dgamma_part_info, _ = ( - transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - False, - False, - kwargs["epsilon"], - get_backward_sm_margin(), - ) + (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + False, + False, + kwargs["epsilon"], + get_backward_sm_margin(), ) wkspace_aval = dx_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = dx_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - dgamma_part_aval = dgamma_aval.update( - shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) - ) - return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval + return dx_aval, 
dgamma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) + dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval @staticmethod @@ -896,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:] + wkspace_aval = ctx.avals_out[-3:] out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), @@ -904,12 +849,6 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), - ir.RankedTensorType.get( - dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) - ), ] operands = [dz, rsigma, x, gamma] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] @@ -921,15 +860,9 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - (0,), # no dbeta_part for RMSnorm jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -942,7 +875,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): @staticmethod def impl(dz, x, rsigma, gamma, epsilon): assert RmsNormBwdPrimitive.inner_primitive is not None - dx, dgamma, _, _, _ = RmsNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind( dz, x, rsigma, gamma, 
epsilon=epsilon ) return dx, dgamma @@ -1066,7 +999,7 @@ def abstract( assert gamma_aval.size == beta_aval.size - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size jax_dtype_to_te_dtype(x_aval.dtype), # in type @@ -1084,18 +1017,15 @@ def abstract( wkspace_aval = x_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = x_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval + return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm fwd (fp8 out) outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = LayerNormFwdFp8Primitive.abstract( + out_aval, mu_aval, rsigma_aval, updated_amax_aval, _ = LayerNormFwdFp8Primitive.abstract( *args, **kwargs ) return out_aval, mu_aval, rsigma_aval, updated_amax_aval @@ -1158,7 +1088,7 @@ def lowering( batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-2:] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1168,9 +1098,6 @@ def lowering( ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, beta, amax, scale, scale_inv] operand_shapes = [ @@ -1189,15 +1116,9 @@ def lowering( batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), 
jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -1215,7 +1136,7 @@ def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, to describe implementation """ assert LayerNormFwdFp8Primitive.inner_primitive is not None - out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( + out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( x, gamma, beta, @@ -1394,7 +1315,7 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp rsigama_dtype = jnp.float32 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch_size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -1412,18 +1333,15 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp wkspace_aval = x_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = x_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval + return out_aval, rsigma_aval, amax_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm fwd (fp8 out) outer primitive abstract """ - out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) + out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) return out_aval, rsigma_aval, amax_aval @staticmethod @@ -1476,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = 
ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-2:] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1485,9 +1403,6 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, amax, scale, scale_inv] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] @@ -1499,15 +1414,9 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -1525,7 +1434,7 @@ def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon): to describe implementation """ assert RmsNormFwdFp8Primitive.inner_primitive is not None - out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( + out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) return out, rsigma, amax diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 02e6aaf9d5..64f3c467b6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -81,25 +81,18 @@ struct CustomCallNormDescriptor { size_t batch_size; size_t hidden_size; size_t wkspace_size; - size_t barrier_size; - Shape dgamma_part_shape; - Shape dbeta_part_shape; DType x_dtype; DType w_dtype; DType wkspace_dtype; - DType 
barrier_dtype; - DType dgamma_part_dtype; - DType dbeta_part_dtype; bool zero_centered_gamma; float eps; int sm_margin; }; -pybind11::bytes PackCustomCallNormDescriptor( - size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, - const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, - DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin); +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, DType x_dtype, DType w_dtype, + DType wkspace_dtype, bool zero_centered_gamma, + float eps, int sm_margin); struct SoftmaxDescriptor { size_t batch_size; diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 9bd9951916..845eb844e2 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -3,9 +3,9 @@ * * See LICENSE for license information. ************************************************************************/ +#include "transformer_engine/normalization.h" + #include "extensions.h" -#include "transformer_engine/layer_norm.h" -#include "transformer_engine/rmsnorm.h" namespace transformer_engine { namespace jax { @@ -25,40 +25,36 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + TensorWrapper dummy_work_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_fwd_func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; if (is_layer_norm) { auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr, - num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { + // TODO(Phuong): Verify and remove this check NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); + rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, + nullptr); } auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype())); } void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size, - size_t barrier_size, bool zero_centered_gamma, float eps, void *input, - DType in_dtype, void *weight, DType w_dtype, void *bias, void *output, - DType out_dtype, void *workspace, DType work_dtype, void *barrier, - DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale, - float *scale_inv, int sm_margin, cudaStream_t stream) { + bool zero_centered_gamma, float eps, void *input, DType in_dtype, + void *weight, DType w_dtype, void *bias, void 
*output, DType out_dtype, + void *workspace, DType work_dtype, void *mu, void *rsigma, float *amax, + float *scale, float *scale_inv, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; auto workspace_shape = std::vector{workspace_size}; - auto barrier_shape = std::vector{barrier_size}; auto is_layer_norm = (bias) ? true : false; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); @@ -71,23 +67,21 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); if (is_layer_norm) { auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, num_sm, - workspace_tensor.data(), barrier_tensor.data()); + nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); + rsigma_tensor.data(), workspace_tensor.data(), num_sm, 
zero_centered_gamma, + stream); } } @@ -96,20 +90,17 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf, Result_Type *output_buf, Result_Type *mu_buf, Result_Type *rsigma_buf, Result_Type *amax_out_buf, - Result_Type *wkspace_buf, Result_Type *barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_, - bool is_layer_norm, bool is_fp8) { + Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_, bool is_layer_norm, bool is_fp8) { auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type()); auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); - auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); auto *input = x_buf->untyped_data(); auto *weight = gamma_buf->untyped_data(); auto *output = (*output_buf)->untyped_data(); auto *rsigma = (*rsigma_buf)->untyped_data(); auto *workspace = (*wkspace_buf)->untyped_data(); - auto *barrier = (*barrier_buf)->untyped_data(); void *bias = nullptr; void *mu = nullptr; @@ -135,17 +126,15 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff auto x_size = product(x_buf->dimensions()); auto gamma_size = product(gamma_buf->dimensions()); auto wkspace_size = product((*wkspace_buf)->dimensions()); - auto barrier_size = product((*barrier_buf)->dimensions()); auto hidden_size = gamma_size; auto batch_size = x_size / gamma_size; float eps = static_cast(eps_); int sm_margin = static_cast(sm_margin_); - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, 
wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); return ffi_with_cuda_error_check(); } @@ -154,11 +143,10 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer Buffer_Type scale_inv_buf, Result_Type output_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type amax_out_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, bool zero_centered_gamma, double eps_, - int64_t sm_margin_) { + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, true, // is_layer_norm true // is_fp8 ); @@ -178,7 +166,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI .Ret() // rsigma .Ret() // amax_out .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -187,15 +174,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, bool zero_centered_gamma, double eps_, - int64_t sm_margin_) { + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, nullptr, // amax_buf nullptr, // scale_buf, nullptr, // scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, nullptr, // amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, true, // is_layer_norm 
false // is_fp8 ); @@ -211,7 +197,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI, .Ret() // mu .Ret() // rsigma .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -221,14 +206,14 @@ Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_T Buffer_Type amax_buf, Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, Result_Type rsigma_buf, Result_Type amax_out_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, nullptr, // beta_buf, &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, nullptr, // mu_buf, - &rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf, - zero_centered_gamma, eps_, sm_margin_, + &rsigma_buf, &amax_out_buf, &wkspace_buf, zero_centered_gamma, + eps_, sm_margin_, false, // is_layer_norm true // is_fp8 ); @@ -246,7 +231,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, .Ret() // rsigma .Ret() // amax_out .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -254,8 +238,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, Result_Type output_buf, Result_Type rsigma_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, nullptr, // beta_buf, nullptr, // amax_buf, @@ -265,7 +249,7 @@ Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type nullptr, // mu_buf, &rsigma_buf, 
nullptr, // amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, false, // is_layer_norm false // is_fp8 ); @@ -279,7 +263,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI, .Ret() // output .Ret() // rsigma .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -303,50 +286,34 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; + TensorWrapper dummy_work_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - // initialize dBeta information here -- layernorm will modify but RMSnorm will not - std::vector dbeta_part_shape; if (is_layer_norm) { auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dbeta_tensor.data(), dummy_dgamma_part_tensor.data(), - dummy_dbeta_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); + dbeta_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, + nullptr); - dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape()); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), 
x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), - xgrad_tensor.data(), wgrad_tensor.data(), dummy_dgamma_part_tensor.data(), - nullptr, num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); - - dbeta_part_shape = std::vector{0, 0}; + xgrad_tensor.data(), wgrad_tensor.data(), dummy_work_tensor.data(), num_sm, + zero_centered_gamma, nullptr); } auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()), - std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()), - std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype())); } void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size, - size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape, bool zero_centered_gamma, float eps, void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, void *workspace, - DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu, - void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part, - DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin, - cudaStream_t stream) { + DType wkspace_dtype, void *mu, void *rsigma, void *xgrad, void *wgrad, + void *dbeta, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -368,28 +335,23 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); auto num_sm = 
cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto workspace_shape = std::vector{wkspace_size}; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto barrier_shape = std::vector{barrier_size}; - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); - auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype); if (is_layer_norm) { auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype); - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dbeta_tensor.data(), dgamma_part_tensor.data(), dbeta_part_tensor.data(), - stream, num_sm, workspace_tensor.data(), barrier_tensor.data()); + dbeta_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, + stream); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), - xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream, - num_sm, workspace_tensor.data(), barrier_tensor.data()); + xgrad_tensor.data(), wgrad_tensor.data(), workspace_tensor.data(), num_sm, + zero_centered_gamma, stream); } } @@ -397,15 +359,11 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu Buffer_Type *mu_buf, Buffer_Type *rsigma_buf, Buffer_Type *gamma_buf, Result_Type *xgrad_buf, Result_Type *wgrad_buf, Result_Type *dbeta_buf, - Result_Type *wkspace_buf, 
Result_Type *barrier_buf, - Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_, - bool is_layer_norm) { + Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_, bool is_layer_norm) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type()); auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); - auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); - auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type()); auto *ograd = dz_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data(); @@ -414,62 +372,37 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu auto *xgrad = (*xgrad_buf)->untyped_data(); auto *wgrad = (*wgrad_buf)->untyped_data(); auto *workspace = (*wkspace_buf)->untyped_data(); - auto *barrier = (*barrier_buf)->untyped_data(); - auto *dgamma_part = (*dgamma_part_buf)->untyped_data(); void *mu = nullptr; void *dbeta = nullptr; - void *dbeta_part = nullptr; - auto dbeta_part_dtype = DType::kByte; if (is_layer_norm) { mu = (*mu_buf).untyped_data(); dbeta = (*dbeta_buf)->untyped_data(); - dbeta_part = (*dbeta_part_buf)->untyped_data(); - dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type()); } auto x_size = product(x_buf->dimensions()); auto gamma_size = product(gamma_buf->dimensions()); auto wkspace_size = product((*wkspace_buf)->dimensions()); - auto barrier_size = product((*barrier_buf)->dimensions()); auto hidden_size = gamma_size; auto batch_size = x_size / gamma_size; - Shape dgamma_part_shape; - auto dgamma_part_dims = (*dgamma_part_buf)->dimensions(); - std::vector dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end()); - 
dgamma_part_shape.from_vector(dgamma_parts_dims_vector); - - Shape dbeta_part_shape; - if (is_layer_norm) { - auto dbeta_part_dims = (*dbeta_part_buf)->dimensions(); - std::vector dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end()); - dbeta_part_shape.from_vector(dbeta_parts_dims_vector); - } else { - dbeta_part_shape.from_vector({0, 0}); - } - float eps = static_cast(eps_); int sm_margin = static_cast(sm_margin_); - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); return ffi_with_cuda_error_check(); } Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - Result_Type dgamma_part_buf, Result_Type dbeta_part_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf, - &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf, - &dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_, - sm_margin_, + &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, + zero_centered_gamma, eps_, sm_margin_, true // is_layer_norm ); } @@ -486,9 +419,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, .Ret() // 
wgrad .Ret() // dbeta .Ret() // wkspace - .Ret() // barrier - .Ret() // dgamma_part - .Ret() // dbeta_part .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -497,15 +427,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, Result_Type dgamma_part_buf, bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, nullptr, // mu_buf &rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf, nullptr, // dbeta_buf, - &wkspace_buf, &barrier_buf, &dgamma_part_buf, - nullptr, // dbeta_part_buf, - zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, false // is_layer_norm ); } @@ -520,8 +447,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI, .Ret() // xgrad .Ret() // wgrad .Ret() // wkspace - .Ret() // barrier - .Ret() // dgamma_part .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -540,7 +465,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto *rsigma = buffers[8]; auto *amax_out = buffers[9]; auto *workspace = buffers[10]; - auto *barrier = buffers[11]; NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive"); @@ -548,21 +472,18 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin 
= desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -573,7 +494,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto *mu = buffers[4]; auto *rsigma = buffers[5]; auto *workspace = buffers[6]; - auto *barrier = buffers[7]; float *amax = nullptr; float *scale = nullptr; @@ -583,20 +503,17 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto out_dtype = in_dtype; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void LayerNormBackward(cudaStream_t 
stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -605,15 +522,9 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - auto dbeta_part_shape = desc.dbeta_part_shape; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = desc.dbeta_part_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; @@ -627,15 +538,10 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto *wgrad = buffers[6]; auto *dbeta = buffers[7]; auto *workspace = buffers[8]; - auto *barrier = buffers[9]; - auto *dgamma_part = buffers[10]; - auto *dbeta_part = buffers[11]; - - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -648,7 +554,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto *rsigma = buffers[6]; auto *amax_out = buffers[7]; auto *workspace = buffers[8]; - auto *barrier = buffers[9]; NVTE_CHECK(amax_out == amax, "amax not bound to amax_out 
in TE/JAX RSMNormForwardFP8 primitive."); void *bias = nullptr; @@ -658,20 +563,17 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -680,7 +582,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto *output = buffers[2]; auto *rsigma = buffers[3]; auto *workspace = buffers[4]; - auto *barrier = buffers[5]; void *bias = nullptr; void *mu = nullptr; @@ -692,20 +593,17 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = in_dtype; 
- LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -716,36 +614,24 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto *xgrad = buffers[4]; auto *wgrad = buffers[5]; auto *workspace = buffers[6]; - auto *barrier = buffers[7]; - auto *dgamma_part = buffers[8]; void *mu = nullptr; void *dbeta = nullptr; - void *dbeta_part = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - Shape dbeta_part_shape; - dbeta_part_shape.from_vector({0, 0}); auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = DType::kByte; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + LayerNormBackwardImpl(batch_size, hidden_size, 
wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..ccc6921f43 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -32,24 +32,17 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shap return PackOpaque(desc); } -pybind11::bytes PackCustomCallNormDescriptor( - size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, - const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, - DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) { +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, DType x_dtype, DType w_dtype, + DType wkspace_dtype, bool zero_centered_gamma, + float eps, int sm_margin) { CustomCallNormDescriptor desc{}; desc.batch_size = batch_size; desc.hidden_size = hidden_size; desc.wkspace_size = wkspace_size; - desc.barrier_size = barrier_size; - desc.dgamma_part_shape.from_vector(dgamma_part_shape); - desc.dbeta_part_shape.from_vector(dbeta_part_shape); desc.x_dtype = x_dtype; desc.w_dtype = w_dtype; desc.wkspace_dtype = wkspace_dtype; - desc.barrier_dtype = barrier_dtype; - desc.dgamma_part_dtype = dgamma_part_dtype; - desc.dbeta_part_dtype = dbeta_part_dtype; desc.zero_centered_gamma = zero_centered_gamma; desc.eps = eps; desc.sm_margin = sm_margin; diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 6ce250432a..9b7e3d767a 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -10,9 +10,8 @@ 
#include #include #include -#include +#include #include -#include #include #include #include diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 583cd0f47a..b35b4434db 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -353,24 +353,23 @@ std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); auto mu_cu = MakeNvteTensor(mu); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + // This call populates workspace tensor with the required config + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), 
beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); return {ln_out, mu, rsigma}; } @@ -394,24 +393,23 @@ std::vector te_layernorm_fwd(const paddle::Tensor &input, auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); auto mu_cu = MakeNvteTensor(mu); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + // This call populates workspace tensor with the required config + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); return {ln_out, mu, 
rsigma}; } @@ -424,7 +422,7 @@ std::vector te_layernorm_bwd(const paddle::Tensor &dz, const pad auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + TensorWrapper workspace; auto dz_cu = MakeNvteTensor(dz); auto x_cu = MakeNvteTensor(x); @@ -438,25 +436,18 @@ std::vector te_layernorm_bwd(const paddle::Tensor &dz, const pad auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), - num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + num_sm - sm_margin, zero_centered_gamma, dz.stream()); // Alloc space for Tensors. auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); - auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype()); // Actual call to bwd kernel. 
- bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), - num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + num_sm - sm_margin, zero_centered_gamma, dz.stream()); return {dx, dgamma, dbeta}; } @@ -477,24 +468,21 @@ std::vector te_rmsnorm_fwd(const paddle::Tensor &input, auto gamma_cu = MakeNvteTensor(weight); auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config - + // This call populates workspace tensor with the required config nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); return {ln_out, rsigma}; } @@ -521,23 +509,21 @@ std::vector 
te_rmsnorm_fwd_fp8(const paddle::Tensor &input, ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config + // This call populates workspace tensor with the required config nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); return {ln_out, rsigma}; } @@ -550,7 +536,7 @@ std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl auto dx = paddle::empty_like(x, x.dtype(), x.place()); auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - TensorWrapper workspace, barrier, dgamma_part; + TensorWrapper workspace; auto dz_cu = MakeNvteTensor(dz); auto x_cu = MakeNvteTensor(x); @@ -563,21 +549,17 @@ std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl // This call populates tensors with the required config. 
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); + dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, + dz.stream()); // Alloc space for Tensors. auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); // Actual call to bwd kernel. nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); + dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, + dz.stream()); return {dx, dgamma}; } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 175a7b0e90..82f58b1eda 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,11 +28,10 @@ #include #include #include -#include +#include #include #include #include -#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 04274ae2ef..2574b84352 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -19,7 +19,7 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx = at::empty_like(x_); 
auto dgamma = at::empty_like(gamma_); auto dbeta = at::empty_like(gamma_); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + transformer_engine::TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); auto x_cu = makeTransformerEngineTensor(x_); @@ -31,32 +31,21 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dbeta_cu = makeTransformerEngineTensor(dbeta); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Alloc space for Tensors. 
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), - dgamma_part.dtype()); - dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(), - dbeta_part.dtype()); // Actual call to bwd kernel. - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {dx, dgamma, dbeta}; } @@ -88,9 +77,6 @@ std::vector layernorm_fwd_fp8_noalloc( const auto &weight_ = weight.contiguous(); const auto &bias_ = bias.contiguous(); - // Choose kernel implementation - const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - // Tensor dimensions size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); @@ -113,24 +99,22 @@ std::vector layernorm_fwd_fp8_noalloc( auto rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + transformer_engine::TensorWrapper workspace; + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Allocate workspaces auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); // Launch kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {ln_out, mu, rsigma}; } @@ -194,7 +178,7 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx = at::empty_like(x_); auto 
dgamma = at::empty_like(gamma_); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part; + transformer_engine::TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); auto x_cu = makeTransformerEngineTensor(x_); @@ -204,27 +188,21 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dgamma_cu = makeTransformerEngineTensor(dgamma); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Alloc space for Tensors. auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), - dgamma_part.dtype()); // Actual call to bwd kernel. 
- bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {dx, dgamma}; } @@ -255,9 +233,6 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a const int scale_inv_offset) { using namespace transformer_engine; - // Choose kernel implementation - const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; - // Tensor dimensions size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); @@ -277,24 +252,22 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a auto rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + transformer_engine::TensorWrapper workspace; + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Allocate workspaces auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = 
makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); // Launch kernel - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {ln_out, rsigma}; } From e4c99b03707190c26aab128ccaf5422dde37e34d Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 11 Dec 2024 01:31:14 +0800 Subject: [PATCH 031/239] [JAX] Use default factory for not sharing mutable default values (#1364) * Bug Fix: Use default factory for not sharing mutable default values --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen --- transformer_engine/jax/praxis/module.py | 25 +++++++++++++++----- transformer_engine/jax/praxis/transformer.py | 9 +++++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index b82c0915e4..e5649bfe7c 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -4,6 +4,7 @@ """ Praxis Modules """ +from dataclasses import field from functools import partial from typing import Callable, Iterable, Sequence, Tuple, Union @@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] 
= () transpose_batch_sequence: bool = False @@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer): out_features: int = 512 kernel_axes: Tuple[str, ...] = () use_bias: bool = True - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes_1: Tuple[str, ...] = () kernel_axes_2: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes_1: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] 
= () enable_low_rank_adaptation: bool = False diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index f2ac802f10..2ae212afb9 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -4,6 +4,7 @@ """ Praxis Modules related Transformer """ +from dataclasses import field from functools import partial from typing import Optional, Sequence, Tuple import warnings @@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): zero_centered_gamma: bool = False return_layernorm_output: bool = False use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) attn_mask_type: str = "causal" attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False @@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer): dropout_rng_name: str = "dropout" mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False float32_attention_logits: bool = False From 0e1d9faed1ef8d341614c31b2fa7694b4a9f39a5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 12 Dec 2024 08:00:46 -0500 Subject: [PATCH 032/239] [JAX] Bug fix for distributed normalization (#1366) * fix ctx.aval_out indexing for workspace * add cudnn init to prepare phase of norm custom calls * add thread_local for norm registry instance --------- Signed-off-by: Phuong Nguyen --- .../common/normalization/common.h | 3 +-- .../jax/cpp_extensions/normalization.py | 12 +++++----- .../jax/csrc/extensions/pybind.cpp | 24 ++++++++++++++----- 3 files changed, 25 insertions(+), 14 
deletions(-) diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 8a8df63ba4..d1d56d5cc9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { class NormalizationPlanRegistry { public: - // TODO thread-safe static NormalizationPlanRegistry& getInstance() { - static NormalizationPlanRegistry instance; + static thread_local NormalizationPlanRegistry instance; return instance; } diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 0b7df0b5a8..69d7962b62 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, output_type), @@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): sm_margin = get_backward_sm_margin() - wkspace_aval = ctx.avals_out[-4:] + wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, @@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), @@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-3:] + wkspace_aval = 
ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), @@ -1088,7 +1088,7 @@ def lowering( batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..a319b74d76 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -83,12 +83,24 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); - dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); - dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); - dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); - dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); - dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); + dict["te_layernorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)); + dict["te_layernorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)); + dict["te_layernorm_backward_ffi"] = + 
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)); + dict["te_rmsnorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)); + dict["te_rmsnorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)); + dict["te_rmsnorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)); // Attention pybind11::dict fused_attn_forward_ffi; From e7bfc0c547d63332e4f8d65e606dc69f4c22ffbe Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 12 Dec 2024 14:16:09 -0800 Subject: [PATCH 033/239] Add user to CI (#1371) Add Jeremy to ci users Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 586abd0541..86d22b7944 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -42,6 +42,7 @@ jobs: || github.actor == 'kocchop' || github.actor == 'youngeunkwon0405' || github.actor == 'KshitijLakhani' + || github.actor == 'jberchtold-nvidia' ) steps: - name: Check if comment is issued by authorized person From 1ae81903a16f274ccdfd199c91634ab9833e4c9a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Dec 2024 18:09:42 -0800 Subject: [PATCH 034/239] Fix an invalid reference in the doc (#1362) --- examples/pytorch/comm_gemm_overlap/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md index bb3ba209ed..fc8458844b 100644 --- 
a/examples/pytorch/comm_gemm_overlap/README.md +++ b/examples/pytorch/comm_gemm_overlap/README.md @@ -16,7 +16,7 @@ Forward and backward passes with layer weights distributed over all GPUs in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7] @@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across groups in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2 +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2 # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3] From 1975ace44b3d4255e2c2e7aa0546d394ab1c9ce3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 14 Dec 2024 12:09:21 -0500 Subject: [PATCH 035/239] [JAX] Bug Fix: Softmax FFIs with correct Encapsulates (#1375) * softmax custom calls with correct encapsulates * rm jax deprecated features --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 6 +++--- transformer_engine/jax/cpp_extensions/base.py | 2 +- .../jax/cpp_extensions/normalization.py | 14 +++++++------- .../jax/cpp_extensions/softmax.py | 8 ++++---- .../jax/csrc/extensions/pybind.cpp | 19 ++++++++----------- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 44b396ad55..7f09e6f900 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir 
from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi @@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument assert x_shape[-2] == 2 or x_shape[-2] == 1 hidden_size = x_shape[-1] batch_shapes = x_shape[:-2] - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval out_shape = (batch_shapes) + (hidden_size,) out_aval = out_aval.update(shape=out_shape, dtype=dtype) @@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument i_hidden_size = dz_aval.shape[-1] g_hidden_size = x_aval.shape[-1] assert i_hidden_size == g_hidden_size - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval return out_aval diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 3d88c1f078..3715e6f20c 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod from functools import partial -from jax import core +from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 69d7962b62..8ad7ee4fcb 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding @@ -74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): mu_rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], 
dtype=mu_rsigama_dtype) assert gamma_aval.size == beta_aval.size @@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] assert mu_dtype == rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) + dx_aval = dz_aval + dgamma_aval = dbeta_aval = gamma_aval (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size @@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs): rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) hidden_size = gamma_aval.size @@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): assert rsigma_aval.shape == x_aval.shape[:-1] assert rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = core.raise_to_shaped(gamma_aval) + dx_aval = dz_aval + dgamma_aval = gamma_aval (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index a12943f4c2..67053ecd8e 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi @@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor): assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported assert q_seqlen > 1 - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod @@ -237,7 +237,7 @@ def 
backward_abstract( assert dz_aval.shape == softmax_out_aval.shape - dx_aval = core.raise_to_shaped(dz_aval) + dx_aval = dz_aval return dx_aval @staticmethod @@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-ar assert mask_shape[-2] == q_seqlen assert mask_shape[-1] == k_seqlen - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a319b74d76..a986b91b30 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -61,26 +61,23 @@ pybind11::dict Registrations() { dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); - dict["te_dact_lu_dbias_cast_transpose_ffi"] = - EncapsulateFunction(DActLuDBiasCastTransposeHandler); - dict["te_dgated_act_lu_cast_transpose_ffi"] = - EncapsulateFunction(DGatedActLuCastTransposeHandler); + dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler); + dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler); // Quantization dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax - dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler); - dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler); - dict["te_scaled_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler); + dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler); + dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler); + dict["te_scaled_masked_softmax_forward_ffi"] = 
EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler); dict["te_scaled_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler); dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler); dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization dict["te_layernorm_forward_ffi"] = From 0196ed4461ad561411aa828d1e9dc89a32ef7177 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Mon, 16 Dec 2024 15:39:47 -0800 Subject: [PATCH 036/239] Enabling FP8 all-gather for TE Float8Tensor when using Torch FSDP2 (#1358) * draft implementation of fsdp2 fp8 all gather Signed-off-by: Youngeun Kwon * fix the convergence issue Signed-off-by: Youngeun Kwon * Add warning Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable lint error Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the lint error Signed-off-by: Youngeun Kwon * fix lint error Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint error Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint error Signed-off-by: Youngeun Kwon * add comments Signed-off-by: Youngeun Kwon * add ref Signed-off-by: Youngeun Kwon * add related tests Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Youngeun Kwon Co-authored-by: 
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/pytorch/distributed/run_fsdp2_model.py | 181 ++++++++++++++++++ tests/pytorch/distributed/test_torch_fsdp2.py | 67 +++++++ .../pytorch/tensor/float8_tensor.py | 92 ++++++++- 4 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 tests/pytorch/distributed/run_fsdp2_model.py create mode 100644 tests/pytorch/distributed/test_torch_fsdp2.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9a11ccc008..4e52153db9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py new file mode 100644 index 0000000000..0f00a6717b --- /dev/null +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -0,0 +1,181 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import os +import sys +import argparse + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn, optim +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import init_device_mesh +from contextlib import nullcontext + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(SimpleNet, self).__init__() + self.fc1 = te.Linear(input_size, hidden_size) + self.fc2 = te.Linear(hidden_size, output_size) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + + +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") + parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") + parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") + parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") + parser.add_argument( + "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." 
+ ) + parser.add_argument( + "--iter", type=int, default=10, help="Number of iterations for forward pass" + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + # Adding hsdp_dim as a list argument, comma-separated + parser.add_argument( + "--sharding-dims", + type=int, + nargs="+", + help='FSDP/HSDP sharding dimensions ("replicate", "shard")', + ) + args = parser.parse_args(argv, namespace) + if args.sharding_dims: + assert len(args.sharding_dims) <= 2 + return args + + +sub_modules_to_wrap = [te.Linear] + + +def _train(args): + assert "TORCHELASTIC_RUN_ID" in os.environ + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + assert LOCAL_SIZE == WORLD_SIZE + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + device = torch.device(f"cuda:{LOCAL_RANK}") + + # FP8 Configuration + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext + build_model_context_args = {} + + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Build the model with the specified context + with build_model_context(**build_model_context_args): + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + else: + model = SimpleNet(args.input_size, args.hidden_size, 
args.output_size) + # Move the model to the correct device + + model.to(device) + + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") + # Creating a DeviceMesh for fully_shard + world_size = int(WORLD_SIZE) + device_ids = list(range(world_size)) + if LOCAL_RANK == 0: + print(f"sharding-dims:{args.sharding_dims}") + # Setup the sharding mesh for FSDP/HSDP + if args.sharding_dims == None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 1: + assert args.sharding_dims[0] == device_ids[-1] + 1 + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 2: # HSDP + assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 + mesh = init_device_mesh( + "cuda", + (args.sharding_dims[0], args.sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + + # Apply FSDP/HSDP + custom_attrs = save_custom_attrs(model) + for sub_module in model.modules(): + if any( + isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, mesh=mesh) + fully_shard(model, mesh=mesh) + restore_custom_attrs(model, custom_attrs) + + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + for iteration in range(args.iter): + # Zero the parameter gradients + optimizer.zero_grad() + input_data = torch.randn(args.batch_size, args.input_size).to(device) + output = model(input_data) + target = torch.randn(args.batch_size, args.output_size).to(device) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + + dist.destroy_process_group() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Done...") + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py new file mode 100644 index 
0000000000..3c9197c322 --- /dev/null +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import pytest +import subprocess +from pathlib import Path +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import torch +from packaging.version import Version as PkgVersion + + +def get_torch_version(): + """Get pytorch version from __version__""" + + def get_torch_version_str(): + import torch + + return str(torch.__version__) + + return PkgVersion(get_torch_version_str()) + + +if torch.cuda.device_count() < 4: + pytest.skip("FSDP2 test requires at least 4 GPUs.") + +if torch.cuda.device_count() % 2 != 0: + pytest.skip("Number of device should be divided by 2.") + +if not get_torch_version() >= PkgVersion("2.4"): + pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = torch.cuda.device_count() +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(fp_init, sharding_dims): + test_path = TEST_ROOT / "run_fsdp2_model.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + if fp_init: + test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: + test_cmd += ["--sharding-dims", str(sharding_dims[0])] + elif len(sharding_dims) == 2: + test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] + else: + assert False + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if result.returncode != 0: + raise AssertionError(result.stderr.decode()) + + +all_boolean = [True, False] +sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]] + + +@pytest.mark.parametrize("sharding_dims", sharding_dims) +@pytest.mark.parametrize("fp8_init", all_boolean) +def test_distributed(fp8_init, sharding_dims): + if 
fp8_init and not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(fp8_init, sharding_dims) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 7ace68a222..414e819f53 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -24,6 +24,19 @@ aten = torch.ops.aten updated_fp8_params = {} +_ops_to_preserve_subclass_in_fsdp2 = { + torch.ops.aten.empty_like.default, + torch.ops.aten.new_zeros.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.copy_.default, + torch.ops.aten.view.default, + torch.ops.aten.as_strided.default, + torch.ops.aten._to_copy.default, + torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, +} + def _make_fp8_attr_property_funcs(name: str) -> Any: """Make accessors for an FP8 attribute @@ -430,6 +443,37 @@ def __new__( return self + def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument + """ + A hook function used in torch fsdp2, called before all-gather + return (all-gather input), (metadata) + Ref: https://github.com/pytorch/pytorch/pull/122908 + + """ + + return (self._data,), (self,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, # pylint: disable=unused-argument + *, + out: Optional[torch.Tensor] = None, + ): + """ + A hook function used in torch fsdp2, called after all-gather + return (Float8Tensor class instance of all-gathered input), (Things to free after forward) + Ref: https://github.com/pytorch/pytorch/pull/122908 + + """ + (data,) = all_gather_outputs + (sample,) = metadata + if out is not None: + assert isinstance(out, Float8Tensor), f"{type(out)}" + return None + return Float8Tensor.make_like(sample, data=data), all_gather_outputs + @classmethod def make_like( cls, @@ -902,7 +946,53 @@ def __torch_dispatch__(cls, func, types, args, 
kwargs=None): ) return Float8Tensor.make_like(tensor, data=data_view) - # Default case + # Related to FSDP2 + if func == aten.split.Tensor: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out] + if func == aten.new_zeros.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out) + if func == torch.ops.aten.as_strided.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out) + if func == torch.ops.aten.detach.default: + return cls.detach(args[0]) + if func == torch.ops.aten.clone.default: + return cls.clone(args[0]) + if func == torch.ops.aten.copy_.default: + # Implementation in the superclass (QuantizedTensor) returns a proper output + pass + elif func in _ops_to_preserve_subclass_in_fsdp2: + # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 + warnings.warn( + f"A function call({func}) in {cls} may not return {cls} tensor as an output. It" + " might cause an error in torch FSDP2!" 
+ ) + else: + pass + return super().__torch_dispatch__(func, types, args, kwargs) @classmethod From f4f35c2f715e8c219ee4f76de2b9e768af062cfe Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 16 Dec 2024 19:57:44 -0800 Subject: [PATCH 037/239] [common] Add max_t support for KV in THD (#1370) add max_t for KV Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index f242502261..b706eadace 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -661,6 +661,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); + sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { From 7f5c784e32391670cd4661f61edbca7912916a6c Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Tue, 17 Dec 2024 15:41:40 +0800 Subject: [PATCH 038/239] [JAX] Fused attention unit tests fixes and refinements (#1352) * Add util functions to attn_mask_type Signed-off-by: Reese Wang * Add util functions to qkv_layout Signed-off-by: Reese Wang * Fix THD cross reference code Signed-off-by: Reese Wang * Remove explicit segment_pad, encoding it to segment_ids Signed-off-by: Reese Wang * Add jax.jit, replace _token with segment_ids, rename bias shape enum Signed-off-by: Reese Wang * Add comment for make_mask Signed-off-by: Reese Wang * Clean code Signed-off-by: Reese Wang * Add doc strings for the added functions Signed-off-by: Reese Wang * Remove cache for fa deterministic which causes UT failed Signed-off-by: Reese Wang * 
Rename fixture to avoid conflict Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/conftest.py | 2 +- tests/jax/test_distributed_fused_attn.py | 6 +- tests/jax/test_fused_attn.py | 227 ++++++++++-------- tests/jax/utils.py | 16 +- transformer_engine/jax/attention.py | 99 +++++--- .../jax/cpp_extensions/attention.py | 3 +- 6 files changed, 201 insertions(+), 152 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index ccb6690a87..5bb86c6081 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -20,7 +20,7 @@ def clear_live_arrays(): @pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): +def enable_fused_attn_after_hopper(): """ Enable fused attn for hopper+ arch. Fused attn kernels on pre-hopper arch are not deterministic. diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index e194a228d2..1538062975 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -20,7 +20,6 @@ from utils import ( make_causal_mask, make_self_mask, - assert_tree_like_allclose, assert_allclose, print_debug_tensor_stats, ) @@ -32,7 +31,6 @@ AttnMaskType, QKVLayout, QKVFormat, - get_qkv_format, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, CPStrategy, @@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn( dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape - qkv_format = get_qkv_format(qkv_layout) + qkv_format = qkv_layout.get_qkv_format() batch, seqlen, num_head, hidden = data_shape @@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient _, max_seq_len, num_heads, _ = data_shape gradient_multiplier = max_seq_len * num_heads - if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]: + if attn_mask_type.is_causal(): gradient_multiplier /= 10 ret_valid = func(*args, **kwargs) 
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index af05538ef5..759ea893ef 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -28,7 +28,6 @@ QKVFormat, fused_attn, fused_attn_thd, - get_qkv_format, make_swa_mask, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper @@ -50,6 +49,7 @@ def init(): yield +@partial(jax.jit, static_argnums=(5, 6, 7, 9)) def general_dot_product_attention( query: ArrayLike, key: ArrayLike, @@ -102,29 +102,36 @@ def general_dot_product_attention( return context -def is_causal_mask(mask: AttnMaskType): - """ - Check if the mask is a causal mask - """ - return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK] - - -def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array: +@jax.jit +def make_causal_mask( + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike = None, + segment_pos_kv: ArrayLike = None, +) -> Array: """ Create inverse padded causal mask where `True` means allowing the corresponding position to participate in attention and `False` means masking out that position. + If segment_pos is not provided, aragne of the segment_ids will be applied. 
""" - q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape) - kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape) - inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal) + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal) return inv_causal_mask +@partial(jax.jit, static_argnums=(4, 5)) def make_mask( - q_token: ArrayLike, - kv_token: ArrayLike, - segment_pad_q: ArrayLike, - segment_pad_kv: ArrayLike, + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike, + segment_pos_kv: ArrayLike, attn_mask_type: AttnMaskType, window_size: Optional[Tuple[int, int]] = None, ) -> Array: @@ -132,18 +139,31 @@ def make_mask( Create attention mask based on mask type. A `True` value in the mask means masking out the corresponding position and a `False` value means allowing that position to participate in attention. + + - segment_ids should start with 1, and using 0s for the paddings. + Expected that each segment starts without paddings. + - segment_pos marks the token position in the segments. 
+ + A example pair of segments_ids and segment_pos: + segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] + segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ inv_mask = make_attention_mask( - q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) + segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) - if is_causal_mask(attn_mask_type): - inv_causal_mask = make_causal_mask(q_token, kv_token) - inv_mask = combine_masks(inv_causal_mask, inv_mask) - if segment_pad_q is not None and segment_pad_kv is not None: - inv_pad_mask = make_attention_mask( - segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) + if attn_mask_type.is_causal(): + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask( + segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) - inv_mask = combine_masks(inv_pad_mask, inv_mask) + inv_mask = combine_masks(inv_causal_mask, inv_mask) if window_size is not None: max_seqlen_q = inv_mask.shape[-2] @@ -157,7 +177,8 @@ def make_mask( return mask -def get_seqlens_and_offsets(segment_ids, segment_pad): +@jax.jit +def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) @@ -165,7 +186,7 @@ def get_seqlens_and_offsets(segment_ids, segment_pad): def _find_offsets(x): same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) - first_column = jnp.ones((x.shape[0], 1), dtype=bool) + first_column = x[..., :1] != 0 same_as_previous = jnp.hstack((first_column, same_as_previous)) return 
jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))( same_as_previous @@ -173,13 +194,9 @@ def _find_offsets(x): offsets = _find_offsets(segment_ids) offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - if segment_pad is not None: - segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids) - padding_aware_seqlen = bincount_vmap(segment_id_with_paddings) - output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1) - else: - output = jnp.insert(seqlens, -1, values=0, axis=-1) - return output, offsets + seqlens = jnp.insert(seqlens, -1, values=0, axis=-1) + seqlens = jnp.where(seqlens, seqlens, -1) + return seqlens, offsets @jax.jit @@ -200,8 +217,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): query, key, value, - bias=bias, - mask=mask, + bias, + mask, deterministic=not kwargs["is_training"], scale_factor=kwargs["scaling_factor"], dropout_rate=kwargs["dropout_probability"], @@ -228,7 +245,6 @@ def customcall_fused_dpa( TE customcall dot product attention implementation """ qkv_layout = kwargs["qkv_layout"] - is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD match qkv_layout: case QKVLayout.BS3HD | QKVLayout.T3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) @@ -242,7 +258,7 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not is_thd: + if not qkv_layout.is_thd(): kwargs.pop("max_segments_per_seq") return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) return fused_attn_thd( @@ -262,10 +278,10 @@ class BiasShape(Enum): Enum class to represent the different bias shapes used in the fused attention. 
""" - BIAS_1HSS = "1HSS" - BIAS_B1SS = "B1SS" - BIAS_BHSS = "BHSS" - BIAS_11SS = "11SS" + _1HSS = "1HSS" + _B1SS = "B1SS" + _BHSS = "BHSS" + _11SS = "11SS" @dataclass @@ -300,18 +316,12 @@ def _get_max_segments_per_sequence(self): def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available - if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ - AttnMaskType.PADDING_MASK, - AttnMaskType.PADDING_CAUSAL_MASK, - ]: + if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") - qkv_format = get_qkv_format(self.qkv_layout) - if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") - - if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: if self.num_heads_q != self.num_heads_kv: pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") @@ -339,15 +349,11 @@ def _check_configs(self): if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS - and self.bias_shape != BiasShape.BIAS_1HSS + and self.bias_shape != BiasShape._1HSS ): - if self.attn_mask_type not in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - ]: + if self.attn_mask_type.is_padding(): pytest.skip( - "B1SS, BHSS and 11SS bias shapes are only supported for " - "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." 
+ "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask" ) elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip( @@ -370,18 +376,18 @@ def _setup_inputs(self): if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None - elif self.bias_shape == BiasShape.BIAS_1HSS: + elif self.bias_shape == BiasShape._1HSS: bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_B1SS: + elif self.bias_shape == BiasShape._B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_BHSS: + elif self.bias_shape == BiasShape._BHSS: bias_shape = ( self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv, ) - elif self.bias_shape == BiasShape.BIAS_11SS: + elif self.bias_shape == BiasShape._11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") @@ -391,7 +397,7 @@ def _setup_inputs(self): self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.bias_shape == BiasShape._1HSS: self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0) else: # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for @@ -408,10 +414,10 @@ def _setup_inputs(self): else: self.bias = None - if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: - pad_ratio = 0.0 - else: + if self.attn_mask_type.is_padding(): pad_ratio = 0.3 + else: + pad_ratio = 0.0 def gen_valid(bs, max_seqlen, pad_ratio): pad_len = int(max_seqlen * pad_ratio) @@ -425,6 +431,8 @@ def generate_random_segment_ids( rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad segment_ids = np.zeros((batch_size, sequence_length), dtype=int) + segment_pos = 
np.zeros((batch_size, sequence_length), dtype=int) + # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0] # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad segment_pad = np.zeros((batch_size, sequence_length), dtype=int) @@ -440,58 +448,62 @@ def generate_random_segment_ids( break segment_end = current_pos + segment_size segment_ids[i, current_pos:segment_end] = segment_id + segment_pos[i, current_pos:segment_end] = np.arange(segment_size) if with_segment_pad: num_valid = rng.integers(1, segment_size + 1) segment_pad[i, current_pos + num_valid : segment_end] = 1 current_pos = segment_end segment_id += 1 segment_pad[i, current_pos:sequence_length] = 1 - return segment_ids, segment_pad - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: + segment_ids, segment_pos, segment_pad = map( + jnp.asarray, [segment_ids, segment_pos, segment_pad] + ) + segment_ids = jnp.where(segment_pad, 0, segment_ids) + return segment_ids, segment_pos, segment_pad + + if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 - self.token_q, self.segment_pad_q = generate_random_segment_ids( + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - # TODO(rewang): Check if qkvpacked supported different q/kv - # TODO(rewang): Causal with different q/kv segment_id fails - if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type): - self.token_kv = self.token_q - self.segment_pad_kv = self.segment_pad_q + if self.qkv_layout == QKVLayout.T3HD: + self.segment_ids_kv = self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q else: - self.token_kv, self.segment_pad_kv = generate_random_segment_ids( + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024, ) - self.pad_q = self.segment_pad_q - self.pad_kv = self.segment_pad_kv + self.seqlens_q, self.offsets_q = 
get_seqlens_and_offsets(self.segment_ids_q) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 - self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio) - self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) - self.segment_pad_q = self.segment_pad_kv = None + self.segment_ids_q, self.pad_q = gen_valid( + self.batch_size, self.max_seqlen_q, pad_ratio + ) + self.segment_ids_kv, self.pad_kv = gen_valid( + self.batch_size, self.max_seqlen_kv, pad_ratio + ) + self.segment_pos_q = self.segment_pos_kv = None + self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None + # For reference code self.mask = make_mask( - self.token_q, - self.token_kv, - self.segment_pad_q, - self.segment_pad_kv, + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, self.attn_mask_type, self.window_size, ) - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets( - self.token_q, self.segment_pad_q - ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( - self.token_kv, self.segment_pad_kv - ) + if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None self.mask_for_customcall = self.mask self.dropout_rng = dropout_key if self.dropout_prob > 0 else None @@ -547,13 +559,11 @@ def test_backward(self): """ self._setup_inputs() - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS: - pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.") def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient gradient_multiplier = self.max_seqlen_q * self.num_heads_q - if is_causal_mask(self.attn_mask_type): + if 
self.attn_mask_type.is_causal(): gradient_multiplier /= 10 # Keep only valid result for the gradient ret_valid = jnp.where( @@ -586,7 +596,7 @@ def grad_func(func, *args, **kwargs): } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2) + arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -629,7 +639,7 @@ def check_dqkv(primitive, reference, pad): check_dqkv(primitive_dk, reference_dk, self.pad_kv) check_dqkv(primitive_dv, reference_dv, self.pad_kv) - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS: + if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -658,16 +668,6 @@ def check_dqkv(primitive, reference, pad): ) -@pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"), - ], -) @pytest.mark.parametrize( "attn_mask_type", [ @@ -736,6 +736,16 @@ class TestFusedAttn: pytest.param(False, id="INFERENCE"), ], ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, 
id="POST_SCALE_BIAS-BHSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), + ], + ) def _test_forward( b, s_q, @@ -779,6 +789,13 @@ def _test_forward( runner.test_forward() @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) def test_backward( b, s_q, diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 78a6225e1f..242bafa5e2 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -19,7 +19,11 @@ from jax import nn as jax_nn from jax import random as jax_random -from transformer_engine.jax.attention import AttnMaskType, make_swa_mask +from transformer_engine.jax.attention import ( + AttnMaskType, + canonicalize_attn_mask_type, + make_swa_mask, +) from transformer_engine.jax.fp8 import DType as TEDType PRNGKey = Any @@ -913,15 +917,7 @@ def apply_swa_mask( window_size: Tuple[int, int] = (-1, -1), ) -> Array: """Apply the sliding window mask to a given mask""" - mask_map = { - "no_mask": AttnMaskType.NO_MASK, - "padding": AttnMaskType.PADDING_MASK, - "causal": AttnMaskType.CAUSAL_MASK, - "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK, - "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - } - _attn_mask_type = mask_map.get(attn_mask_type, None) + _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type) assert _attn_mask_type is not None max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 3ecc9bcd75..53451b6a78 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -46,6 +46,42 @@ class AttnMaskType(Enum): CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK 
PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK + def is_causal(self): + """Returns True if the mask is a causal mask""" + return self in [ + AttnMaskType.CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_padding(self): + """Returns True if the mask includes padding""" + return self in [ + AttnMaskType.PADDING_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_bottom_right(self): + """Returns True if the causal mask is calculated from the bottom-right section""" + return self in [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + +class QKVFormat(Enum): + """ + SBHD: q,k,v memory layout with [s, b, ..., h, d] + BSHD: q,k,v memory layout with [b, s, ..., h, d] + THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. + """ + + SBHD = NVTE_QKV_Format.NVTE_SBHD + BSHD = NVTE_QKV_Format.NVTE_BSHD + THD = NVTE_QKV_Format.NVTE_THD + class QKVLayout(Enum): """ @@ -66,17 +102,35 @@ class QKVLayout(Enum): THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD - -class QKVFormat(Enum): - """ - SBHD: q,k,v memory layout with [s, b, ..., h, d] - BSHD: q,k,v memory layout with [b, s, ..., h, d] - THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. 
- """ - - SBHD = NVTE_QKV_Format.NVTE_SBHD - BSHD = NVTE_QKV_Format.NVTE_BSHD - THD = NVTE_QKV_Format.NVTE_THD + def get_qkv_format(self): + """ + Return the corresponding qkv_format (BSHD, SBHD, THD) + """ + return QKVFormat(nvte_get_qkv_format(self.value)) + + def is_qkvpacked(self): + """ + Return True if the query, key, value is packed + """ + return self in [QKVLayout.BS3HD, QKVLayout.T3HD] + + def is_kvpacked(self): + """ + Return True if the key, value is packed + """ + return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD] + + def is_separate(self): + """ + Return True if the query, key, value are three separate tensors + """ + return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD] + + def is_thd(self): + """ + Return True if the layout belongs to THD + """ + return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] class CPStrategy(Enum): @@ -92,13 +146,6 @@ class CPStrategy(Enum): RING = 2 -def get_qkv_format(qkv_layout): - """ - Get qkv_format from qkv_layout - """ - return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) - - def make_swa_mask( max_seqlen_q: int, max_seqlen_kv: int, @@ -136,12 +183,8 @@ def make_swa_mask( swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) if window_size is None: return swa_mask - bottom_right_masks = [ - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - ] left_window, right_window = window_size - if attn_mask_type in bottom_right_masks: + if attn_mask_type.is_bottom_right(): if left_window < 0: left_window = max_seqlen_kv if right_window < 0: @@ -310,7 +353,7 @@ def fused_attn( (jnp.ndarray): The output tensor from the fused attention. """ assert ( - get_qkv_format(qkv_layout) != QKVFormat.THD + not qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format." 
# Check inputs qkv @@ -327,11 +370,7 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - ]: + if not attn_mask_type.is_padding(): batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -448,7 +487,7 @@ def fused_attn_thd( QKVLayout.T3HD, 0.125, 0, True, 3) """ assert ( - get_qkv_format(qkv_layout) == QKVFormat.THD + qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format." # Check inputs qkv diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6591861057..f3dfca21ef 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,7 +3,7 @@ # See LICENSE for license information. 
"""JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce, cache +from functools import partial, reduce import operator import os from typing import Optional, Tuple @@ -133,7 +133,6 @@ def get_fused_attn_backend(self): ) @staticmethod - @cache def is_non_deterministic_allowed(): """Check if non-deterministic kernels are allowed""" return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) From 83dac8cf30d8abe2af421eb82ffd1c5a4fc859cb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:15:37 -0800 Subject: [PATCH 039/239] [PyTorch] Add weights_only=False for torch.load (#1374) add weights_only=False for torch.load Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_float8tensor.py | 2 +- tests/pytorch/test_sanity.py | 2 +- tests/pytorch/test_torch_save_load.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 51f4c695dc..a25ffa773c 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -339,7 +339,7 @@ def test_serialization( del x_fp8, byte_stream # Deserialize tensor - x_fp8 = torch.load(io.BytesIO(x_bytes)) + x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False) del x_bytes # Check results diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 4f057c12fe..32d517460a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1101,7 +1101,7 @@ def get_model(dtype, config): del block block = get_model(dtype, config) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) torch.set_rng_state(_cpu_rng_state_new) torch.cuda.set_rng_state(_cuda_rng_state_new) diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 
7bf8fb99d5..be77109cb7 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -124,7 +124,7 @@ def forward(self, inp, weight): torch.save(model_in.state_dict(), tmp_filename) model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename)) + model_out.load_state_dict(torch.load(tmp_filename, weights_only=False)) model_out.eval() # scaling fwd @@ -263,7 +263,7 @@ def test_fp8_model_checkpoint( # to load the fp8 metadata before loading tensors. # # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that loaded model matches saved model @@ -450,7 +450,7 @@ def train_step( torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that new model's FP8 metadata matches saved model From f033498f6c941b190c869bfa09310c2de3efd2c9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:15:47 -0800 Subject: [PATCH 040/239] [PyTorch] Fix get_swa_mask() for padding masks (#1281) * WIP: fix get_swa_mask for padding Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix mask type setting Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix the order of checking valid swa and changing mask type Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revamp to get full mask Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto 
fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 28 +-- transformer_engine/pytorch/attention.py | 227 ++++++++++++-------- 2 files changed, 157 insertions(+), 98 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4e995dabb1..dea31b5971 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_0": ModelConfig(4, 
16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "swa_6_0": ModelConfig( + 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_1": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8c529c58d0..be0d176520 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1024,27 +1024,51 @@ def swap_key_value_dict(self, batch_indices): @torch.no_grad() -def get_swa_mask( - window_size: Tuple[int, int], +def get_full_mask( max_seqlen_q: int, max_seqlen_kv: int, attn_mask_type: str = "no_mask", - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + window_size: Tuple[int, int] = None, + attention_type: str = "self", + bottom_right_alignment: bool = True, ) -> torch.Tensor: """ - Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. - For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner, - and for other mask types, the bottom right corner. + Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`, + `attention_mask`, and `window_size`. 
For sliding window attention, the diagonal alignment depends + on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.:: + + attn_mask_type output shape diagonal alignment + -------------------------------------------------------------------------------------------- + no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left + causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right + padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left + padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right + arbitrary same as attention_mask follow bottom_right_alignment + + .. note:: + + For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right + diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, + i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4, + max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = ( + [[False, False, True, True], [False, False, False, False]], + [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4] + shape and is,:: + + [[[False, False, False, True], + [False, False, False, True], + [ True, True, True, True], + [ True, True, True, True]], + [[False, True, True, True], + [False, True, True, True], + [False, True, True, True], + [False, True, True, True]]] Parameters ---------- - window_size: Tuple[int, int] - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. 
Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. Both `causal` and `causal_bottom_right` masks - map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on - `attn_mask_type`. max_seqlen_q: int Maximum sequence length for queries. max_seqlen_kv: int @@ -1052,33 +1076,105 @@ def get_swa_mask( attn_mask_type: str, default = `no_mask` Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` - Boolean tensor(s) used to mask out attention softmax input. + Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention + for the requirements of `attention_mask` for different `attn_mask_type`s. + window_size: Tuple[int, int], default = `None` + Sliding window size for local attention, where query at position i attends to keys + in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window and causal mask specifically. Both `causal` and `causal_bottom_right` masks + map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on + `attn_mask_type`. + attention_type: str, default = "self" + Attention type, {"self", "cross"} + bottom_right_alignment: bool, default = `True` + Whether to align the diagonal of the sliding window attention to the bottom right (`True`) + or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly + specifies "causal" or "causal_bottom_right". 
Returns ---------- + attn_mask_type: str + For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type` attention_mask: torch.Tensor - Combined `attention_mask` (input) and sliding window attention mask. - The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None; - else, the same shape as input `attention_mask`. + The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` + actual_seqlens_q: torch.Tensor + For padding masks, the actual sequence lengths for queries, in shape [batch_size]. + For other masks, `None`. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. + For other masks, `None`. """ - mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda") - if attn_mask_type in ["causal"]: - left = window_size[0] if window_size[0] != -1 else max_seqlen_q - right = window_size[1] if window_size[1] != -1 else max_seqlen_q - mask_upper = torch.triu(mask, diagonal=-left) - mask_lower = torch.tril(mask_upper, diagonal=right) - else: - left = window_size[0] if window_size[0] != -1 else max_seqlen_kv - right = window_size[1] if window_size[1] != -1 else max_seqlen_kv - mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left) - mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right) - attn_mask_type = "arbitrary" - mask = mask_lower.logical_not() + # perform basic checks + change_type = window_size is not None and ( + window_size[0] != -1 or window_size[1] not in [-1, 0] + ) + if window_size is None: + window_size = (-1, -1) + if "causal" in attn_mask_type: + window_size = (window_size[0], 0) + window_size = ( + max_seqlen_kv if window_size[0] == -1 else window_size[0], + max_seqlen_q if window_size[1] == -1 else window_size[1], + ) + + # apply padding mask + actual_seqlens_q = None + actual_seqlens_kv = None + if 
"padding" in attn_mask_type: + if attention_type == "self": + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + m = attention_mask.logical_not() + actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) + actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) + + # apply SWA mask + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + swa_left = None + swa_right = None + if attn_mask_type == "causal_bottom_right" or ( + attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment + ): + swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] + swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] + elif attn_mask_type in ["causal", "padding_causal"] or ( + attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment + ): + swa_left = mask - window_size[0] + swa_right = mask + window_size[1] + elif attn_mask_type == "padding_causal_bottom_right" or ( + attn_mask_type == "padding" and bottom_right_alignment + ): + batch_size = attention_mask.shape[0] + swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q - window_size[0] + ).view(batch_size, 1, 1, 1) + swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q + window_size[1] + ).view(batch_size, 1, 1, 1) + swa_mask = torch.logical_not( + torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) + ) if attention_mask is not None: - mask = torch.logical_and(attention_mask, mask) - return attn_mask_type, mask + attention_mask = torch.logical_or(swa_mask, attention_mask) + else: + attention_mask = swa_mask + + # change mask type + if change_type: + attn_mask_type = 
"arbitrary" + + return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv @torch.no_grad() @@ -4733,6 +4829,7 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + window_size: Optional[Tuple[int, int]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -4752,53 +4849,15 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) - if "padding" in attn_mask_type: - if self.attention_type == "self": - assert attention_mask.shape == ( - batch_size, - 1, - 1, - max_seqlen_q, - ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" - attention_mask = torch.logical_or( - attention_mask.squeeze(1).unsqueeze(3), attention_mask - ) - else: - assert ( - len(attention_mask) == 2 - and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) - and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) - ), ( - "attention_mask should be a tuple of two tensors with shapes " - "[b, 1, 1, sq] and [b, 1, 1, skv]!" 
- ) - attention_mask = torch.logical_or( - attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] - ) - mask = attention_mask.squeeze(1).logical_not() - actual_seqlens_q = mask[:, :, 0].sum(dim=1) - actual_seqlens_kv = mask[:, 0, :].sum(dim=1) - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - if attn_mask_type == "padding_causal": - attention_mask = torch.logical_or( - torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), - attention_mask, - ) - if attn_mask_type == "padding_causal_bottom_right": - attention_mask = torch.logical_or( - torch.where( - mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - < 0, - 1, - 0, - ), - attention_mask, - ) + + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -8274,12 +8333,6 @@ def forward( ) if use_unfused_attention: - if window_size is not None and ( - window_size[0] != -1 or window_size[1] not in [-1, 0] - ): - attn_mask_type, attention_mask = get_swa_mask( - window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask - ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -8291,6 +8344,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -8304,6 +8358,7 @@ def 
forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, From a3b32ec6cb15dac8dc96ae03e40f51dfd072f195 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Dec 2024 10:47:36 -0500 Subject: [PATCH 041/239] [JAX] Move parallel encoder tests to L0 distributed test set. (#1356) * Move test distributed encoder to L0 distributed test suit --------- Signed-off-by: Phuong Nguyen Co-authored-by: Reese Wang --- qa/L0_jax_distributed_unittest/test.sh | 15 +++++++++++++++ qa/L0_jax_unittest/test.sh | 3 +-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 qa/L0_jax_distributed_unittest/test.sh diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh new file mode 100644 index 0000000000..f9e16793a4 --- /dev/null +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +set -xe + +: ${TE_PATH:=/opt/transformerengine} + +pip install -r $TE_PATH/examples/jax/encoder/requirements.txt + +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index db3aa31951..278a3c8b44 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -20,5 +20,4 @@ pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py From 838345eba4fdd2a169dd9e087d39c30a360e684a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Dec 2024 21:32:41 -0800 Subject: [PATCH 042/239] [common/PyTorch] Add cuDNN SWA (left, 0) + padding + bottom right causal (#1378) * add swa (left,0) + padding + brcm support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * final fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * upgrade to FE 1.9-rc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip thd + CP + fused attn tests for cuDNN 9.6+ due to different stats shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- qa/L0_pytorch_unittest/test.sh | 2 +- tests/jax/test_fused_attn.py | 18 +- tests/pytorch/fused_attn/test_fused_attn.py | 186 ++++++++++++------ .../fused_attn/test_fused_attn_with_cp.py | 2 + .../common/fused_attn/fused_attn.cpp | 49 +++-- .../fused_attn_f16_arbitrary_seqlen.cu | 6 +- transformer_engine/pytorch/attention.py | 31 ++- 8 files changed, 195 insertions(+), 101 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 936021bfed..cc5632eda7 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 +Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17307574a9..61dd15d015 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true 
--log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py @@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py +NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 759ea893ef..10da7486cf 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -170,8 +170,7 @@ def make_mask( max_seqlen_kv = inv_mask.shape[-1] inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - # In inv_swa_mask and inv_mask 0 is masked out - inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self): return self.num_segments_per_seq + 1 def _check_configs(self): + # TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference + if ( + self.qkv_layout.is_thd() + and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK + and self.window_size is not None + ): + pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.") # TODO(rewang): probably adds this in is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -504,7 +510,13 @@ def generate_random_segment_ids( if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - 
self.mask_for_customcall = self.mask + self.mask_for_customcall = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index dea31b5971..588e6e4ecd 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -237,19 +237,18 @@ def test_dot_product_attention( tols = dict(atol=1.5e-2, rtol=1.5e-2) config = model_configs[model] is_mla = config.head_dim_qk != config.head_dim_v + is_mqa_gqa = config.num_heads != config.num_gqa_groups if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" + qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd" else: - qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" + qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") - # Test backend availability - window_size = (-1, -1) - if swa: - window_size = [2, 2] - config.window_size = check_set_window_size(config.attn_mask_type, window_size) + if config.window_size == (-1, -1) and swa: + config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) available_backends, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, @@ -334,16 +333,16 @@ def test_dot_product_attention( is_training, ) - if unfused_attn_supported and fused_attn_supported: - logging.info("[test_dot_product_attention]: unfused attn vs fused attn") - torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) - for i, _ in enumerate(unfused_attn_bwd): - 
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) + if unfused_attn_supported and fused_attn_supported: + logging.info("[test_dot_product_attention]: unfused attn vs fused attn") + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + for i, _ in enumerate(unfused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) @@ -399,30 +398,41 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"), - "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, 
"padding_causal", "no_bias"), - "mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "mask_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), + "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), + "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", 
"no_bias"), + "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"), + "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), + "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_10_0": ModelConfig( + 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_8_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_10_1": ModelConfig( + 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), } @@ -531,20 +541,28 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_0": 
ModelConfig( - 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" - ), + "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), "swa_6_1": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_2": ModelConfig( + 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_3": ModelConfig( 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" ), } @@ -623,18 +641,57 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": 
ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), - "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_2_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_3_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + 
"layout_3_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_4_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_2": ModelConfig( + 2, + 24, + 24, + 128, + 2048, + 4096, + 0.0, + "padding_causal_bottom_right", + "no_bias", + window_size=(4, 0), + ), } @@ -651,11 +708,13 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): config = model_configs[model] if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True") pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) if get_cudnn_version() >= (9, 3, 0): + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False") # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run pad_between_seqs = False test_dot_product_attention( @@ -695,9 +754,12 @@ def _run_dot_product_attention( ) seqlens_kv = seqlens_q if config.attn_type == "cross": - seqlens_q = torch.randint( - 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" - ) + if config.max_seqlen_q > 1: + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, 
device="cuda") seqlens_kv = torch.randint( 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" ) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 1007d6aa34..fd8e543adc 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") + if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0): + pytest.skip("THD format is not supported for cuDNN 9.6+!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9cde765401..32e6d4df8f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,6 +135,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset) { flag_m512 = true; } + // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging if ( // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && @@ -152,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - ((cudnn_runtime_version >= 8906) && + (cudnn_runtime_version >= 8906 && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && 
attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && @@ -161,43 +162,67 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - ((cudnn_runtime_version >= 90000) && + (cudnn_runtime_version >= 90000 && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && // mask type + // pre-8.9.6: causal ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - ((cudnn_runtime_version >= 8906) && + // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} + (cudnn_runtime_version >= 8906 && + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - ((cudnn_runtime_version >= 90300) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 9.1: adds thd + {padding, padding_causal} + (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90300 && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.6: adds {bshd, 
sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90600 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format - ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 90600)))) && + cudnn_runtime_version >= 90600))) && // sliding window + // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) && + qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || + // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + (cudnn_runtime_version >= 90600 && + 
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)))) && // check 64-bit ragged offset support (supported_ragged_offset_size)) { flag_arb = true; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index b706eadace..cade624c8d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -71,7 +71,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_bottom_right = false; } bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -451,7 +452,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_bottom_right = false; } bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = 
cudnnGetVersion(); diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index be0d176520..9268b9636e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -602,6 +602,12 @@ def get_attention_backend( "Disabling FusedAttention as it does not support context parallelism with MLA" ) use_fused_attention = False + elif cudnn_version >= (9, 6, 0) and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD for" + " cuDNN 9.6+" + ) + use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends @@ -618,9 +624,7 @@ def get_attention_backend( # self-attention | | All # cross-attention | | FusedAttention, UnfusedDotProductAttention # causal_bottom_right | None | All - # padding_causal_bottom_right | Same as "padding" | - # self-attention | | All - # cross-attention | | FlashAttention, UnfusedDotProductAttention + # padding_causal_bottom_right | Same as "padding" | All # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": @@ -697,29 +701,16 @@ def get_attention_backend( " for FP8" ) use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd": + elif window_size[1] != 0 or attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " - "with causal mask, no dropout, and qkv_format = bshd/sbhd" - ) - use_fused_attention = False - elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ - "no_mask", - "padding", - "causal_bottom_right", - "padding_causal_bottom_right", - ]: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s for cross-attention", - attn_mask_type, + "with (left, 0) and no dropout" ) use_fused_attention = False - elif 
"padding" in attn_mask_type: + elif max_seqlen_q > max_seqlen_kv: logger.debug( "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s", - attn_mask_type, + "with s_q > s_kv for cross-attention" ) use_fused_attention = False if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): From c9ea6be92948e1ec553037f1a04900617b9f7f6b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 2 Jan 2025 14:20:59 -0800 Subject: [PATCH 043/239] Update copyright to include 2025 (#1388) Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/blossom-ci.yml | 2 +- .github/workflows/build.yml | 2 +- .github/workflows/deploy_nightly_docs.yml | 2 +- .github/workflows/docs.yml | 2 +- .github/workflows/license.yml | 2 +- .github/workflows/lint.yml | 2 +- .github/workflows/trigger-ci.yml | 2 +- .github/workflows/upload-ci-logs.yml | 2 +- CONTRIBUTING.rst | 2 +- CPPLINT.cfg | 2 +- README.rst | 2 +- benchmarks/attention/benchmark_attention.py | 2 +- build_tools/__init__.py | 2 +- build_tools/build_ext.py | 2 +- build_tools/jax.py | 2 +- build_tools/paddle.py | 2 +- build_tools/pytorch.py | 2 +- build_tools/te_version.py | 2 +- build_tools/utils.py | 2 +- build_tools/wheel_utils/Dockerfile.aarch | 2 +- build_tools/wheel_utils/Dockerfile.x86 | 2 +- build_tools/wheel_utils/build_wheels.sh | 2 +- build_tools/wheel_utils/launch_aarch.sh | 2 +- build_tools/wheel_utils/launch_x86.sh | 2 +- docs/api/c/activation.rst | 2 +- docs/api/c/cast.rst | 2 +- docs/api/c/fused_attn.rst | 2 +- docs/api/c/gemm.rst | 2 +- docs/api/c/index.rst | 2 +- docs/api/c/layer_norm.rst | 2 +- docs/api/c/rmsnorm.rst | 2 +- docs/api/c/softmax.rst | 2 +- docs/api/c/transformer_engine.rst | 2 +- docs/api/c/transpose.rst | 2 +- docs/api/common.rst | 2 +- docs/api/framework.rst | 2 +- docs/api/jax.rst | 2 +- docs/api/paddle.rst | 2 +- docs/api/pytorch.rst | 2 +- docs/conf.py | 2 +- 
docs/examples/attention/arbitrary_mask_to_post_scale_bias.py | 2 +- docs/examples/attention/example_attention.py | 2 +- docs/examples/quickstart_utils.py | 2 +- docs/examples/te_llama/te_llama.py | 2 +- docs/examples/te_llama/utils.py | 2 +- docs/faq.rst | 2 +- docs/index.rst | 2 +- docs/installation.rst | 2 +- examples/jax/encoder/common.py | 2 +- examples/jax/encoder/test_model_parallel_encoder.py | 2 +- examples/jax/encoder/test_multigpu_encoder.py | 2 +- examples/jax/encoder/test_multiprocessing_encoder.py | 2 +- examples/jax/encoder/test_single_gpu_encoder.py | 2 +- examples/jax/mnist/test_single_gpu_mnist.py | 2 +- examples/paddle/mnist/test_single_gpu_mnist.py | 2 +- examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py | 2 +- examples/pytorch/fsdp/README.md | 2 +- examples/pytorch/fsdp/fsdp.py | 2 +- examples/pytorch/mnist/main.py | 2 +- qa/L0_cppunittest/test.sh | 2 +- qa/L0_jax_distributed_unittest/test.sh | 2 +- qa/L0_jax_lint/test.sh | 2 +- qa/L0_jax_unittest/test.sh | 2 +- qa/L0_jax_wheel/test.sh | 2 +- qa/L0_license/copyright_checker.py | 2 +- qa/L0_license/test.sh | 2 +- qa/L0_paddle_lint/test.sh | 2 +- qa/L0_paddle_unittest/test.sh | 2 +- qa/L0_paddle_wheel/test.sh | 2 +- qa/L0_pytorch_lint/test.sh | 2 +- qa/L0_pytorch_unittest/test.sh | 2 +- qa/L0_pytorch_wheel/test.sh | 2 +- qa/L1_jax_distributed_unittest/test.sh | 2 +- qa/L1_pytorch_distributed_unittest/test.sh | 2 +- qa/L1_pytorch_mcore_integration/test.sh | 2 +- qa/L1_pytorch_onnx_test/test.sh | 2 +- qa/L3_pytorch_FA_versions_test/test.sh | 2 +- qa/L3_pytorch_convergence_test/test.sh | 2 +- qa/format.sh | 2 +- setup.py | 2 +- tests/cpp/CMakeLists.txt | 2 +- tests/cpp/operator/CMakeLists.txt | 2 +- tests/cpp/operator/test_act.cu | 2 +- tests/cpp/operator/test_cast_transpose.cu | 2 +- tests/cpp/operator/test_cast_transpose_dbias.cu | 2 +- tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu | 2 +- tests/cpp/operator/test_cast_transpose_dgeglu.cu | 2 +- 
tests/cpp/operator/test_causal_softmax.cu | 2 +- tests/cpp/operator/test_multi_cast_transpose.cu | 2 +- tests/cpp/operator/test_multi_padding.cu | 2 +- tests/cpp/operator/test_normalization.cu | 2 +- tests/cpp/operator/test_qdq.cu | 2 +- tests/cpp/operator/test_transpose.cu | 2 +- tests/cpp/test_common.cu | 2 +- tests/cpp/test_common.h | 2 +- tests/cpp/util/CMakeLists.txt | 2 +- tests/cpp/util/test_nvrtc.cpp | 2 +- tests/cpp/util/test_string.cpp | 2 +- tests/jax/conftest.py | 2 +- tests/jax/distributed_test_base.py | 2 +- tests/jax/pytest.ini | 2 +- tests/jax/test_custom_call_compute.py | 2 +- tests/jax/test_distributed_fused_attn.py | 2 +- tests/jax/test_distributed_layernorm.py | 2 +- tests/jax/test_distributed_layernorm_mlp.py | 2 +- tests/jax/test_distributed_softmax.py | 2 +- tests/jax/test_functions.py | 2 +- tests/jax/test_fused_attn.py | 2 +- tests/jax/test_helper.py | 2 +- tests/jax/test_layer.py | 2 +- tests/jax/test_misc.py | 2 +- tests/jax/test_praxis_layers.py | 2 +- tests/jax/test_sanity_import.py | 2 +- tests/jax/test_sharding.py | 2 +- tests/jax/test_softmax.py | 2 +- tests/jax/utils.py | 2 +- tests/paddle/dist_launcher.py | 2 +- tests/paddle/parallel_tests/amax_reduction.py | 2 +- tests/paddle/parallel_tests/attention_tp.py | 2 +- tests/paddle/parallel_tests/group_sharding.py | 2 +- tests/paddle/parallel_tests/layernorm_linear_tp.py | 2 +- tests/paddle/parallel_tests/layernorm_mlp_tp.py | 2 +- tests/paddle/parallel_tests/linear_pp.py | 2 +- tests/paddle/parallel_tests/linear_tp.py | 2 +- tests/paddle/parallel_tests/transformer_tp.py | 2 +- tests/paddle/recompute_tests/recompute_transformer_encoder.py | 2 +- tests/paddle/test_install.py | 2 +- tests/paddle/test_layers.py | 2 +- tests/paddle/test_master_grad.py | 2 +- tests/paddle/test_operators.py | 2 +- tests/paddle/test_parallel.py | 2 +- tests/paddle/test_recompute.py | 2 +- tests/paddle/test_sanity_import.py | 2 +- tests/paddle/utils.py | 2 +- tests/pytorch/custom_ort_ops/CMakeLists.txt | 2 +- 
tests/pytorch/custom_ort_ops/build.sh | 2 +- tests/pytorch/custom_ort_ops/custom_op_library.cc | 2 +- tests/pytorch/custom_ort_ops/custom_op_library.h | 2 +- tests/pytorch/distributed/print_logs.py | 2 +- tests/pytorch/distributed/run_fsdp2_model.py | 2 +- tests/pytorch/distributed/run_gemm_with_overlap.py | 2 +- tests/pytorch/distributed/run_layer_with_overlap.py | 2 +- tests/pytorch/distributed/run_megatron_lm_gpt.sh | 2 +- tests/pytorch/distributed/run_numerics.py | 2 +- tests/pytorch/distributed/test_comm_gemm_overlap.py | 2 +- tests/pytorch/distributed/test_convergence.py | 2 +- tests/pytorch/distributed/test_fusible_ops.py | 2 +- tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py | 2 +- tests/pytorch/distributed/test_numerics.py | 2 +- tests/pytorch/distributed/test_torch_fsdp2.py | 2 +- tests/pytorch/fused_attn/run_fused_attn_with_cp.py | 2 +- tests/pytorch/fused_attn/test_fused_attn.py | 2 +- tests/pytorch/fused_attn/test_fused_attn_with_cp.py | 2 +- tests/pytorch/test_cuda_graphs.py | 2 +- tests/pytorch/test_deferred_init.py | 2 +- tests/pytorch/test_float8tensor.py | 2 +- tests/pytorch/test_fused_optimizer.py | 2 +- tests/pytorch/test_fused_rope.py | 2 +- tests/pytorch/test_fusible_ops.py | 2 +- tests/pytorch/test_gqa.py | 2 +- tests/pytorch/test_jit.py | 2 +- tests/pytorch/test_multi_tensor.py | 2 +- tests/pytorch/test_numerics.py | 2 +- tests/pytorch/test_onnx_export.py | 2 +- tests/pytorch/test_permutation.py | 2 +- tests/pytorch/test_recipe.py | 2 +- tests/pytorch/test_sanity.py | 2 +- tests/pytorch/test_sanity_import.py | 2 +- tests/pytorch/test_torch_save_load.py | 2 +- tests/pytorch/utils.py | 2 +- transformer_engine/__init__.py | 2 +- transformer_engine/common/CMakeLists.txt | 2 +- transformer_engine/common/__init__.py | 2 +- transformer_engine/common/activation/activation_template.h | 2 +- transformer_engine/common/activation/gelu.cu | 2 +- transformer_engine/common/activation/relu.cu | 2 +- 
transformer_engine/common/activation/swiglu.cu | 2 +- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 +- .../common/comm_gemm_overlap/userbuffers/ipcsocket.cc | 2 +- .../common/comm_gemm_overlap/userbuffers/ipcsocket.h | 2 +- .../common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 2 +- .../common/comm_gemm_overlap/userbuffers/userbuffers.cu | 2 +- .../common/comm_gemm_overlap/userbuffers/userbuffers.h | 2 +- transformer_engine/common/common.cu | 2 +- transformer_engine/common/common.h | 2 +- transformer_engine/common/cudnn_utils.cpp | 2 +- transformer_engine/common/cudnn_utils.h | 2 +- transformer_engine/common/fused_attn/fused_attn.cpp | 2 +- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 2 +- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.h | 2 +- .../common/fused_attn/fused_attn_f16_max512_seqlen.cu | 2 +- .../common/fused_attn/fused_attn_f16_max512_seqlen.h | 2 +- transformer_engine/common/fused_attn/fused_attn_fp8.cu | 2 +- transformer_engine/common/fused_attn/fused_attn_fp8.h | 2 +- transformer_engine/common/fused_attn/thd_utils.cu | 2 +- transformer_engine/common/fused_attn/thd_utils.h | 2 +- transformer_engine/common/fused_attn/utils.cu | 2 +- transformer_engine/common/fused_attn/utils.h | 2 +- transformer_engine/common/fused_rope/fused_rope.cu | 2 +- .../fused_softmax/scaled_aligned_causal_masked_softmax.cu | 2 +- .../common/fused_softmax/scaled_masked_softmax.cu | 2 +- .../common/fused_softmax/scaled_upper_triang_masked_softmax.cu | 2 +- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- .../common/include/transformer_engine/activation.h | 2 +- transformer_engine/common/include/transformer_engine/cast.h | 2 +- .../common/include/transformer_engine/cast_transpose_noop.h | 2 +- .../common/include/transformer_engine/comm_gemm_overlap.h | 2 +- transformer_engine/common/include/transformer_engine/cudnn.h | 2 +- .../common/include/transformer_engine/fused_attn.h | 2 +- .../common/include/transformer_engine/fused_rope.h | 2 
+- transformer_engine/common/include/transformer_engine/gemm.h | 2 +- .../common/include/transformer_engine/normalization.h | 2 +- transformer_engine/common/include/transformer_engine/padding.h | 2 +- .../common/include/transformer_engine/permutation.h | 2 +- transformer_engine/common/include/transformer_engine/recipe.h | 2 +- transformer_engine/common/include/transformer_engine/softmax.h | 2 +- .../common/include/transformer_engine/transformer_engine.h | 2 +- .../common/include/transformer_engine/transpose.h | 2 +- transformer_engine/common/normalization/common.cpp | 2 +- transformer_engine/common/normalization/common.h | 2 +- transformer_engine/common/normalization/kernel_traits.h | 2 +- transformer_engine/common/normalization/layernorm/ln_api.cpp | 2 +- .../common/normalization/layernorm/ln_bwd_kernels.cuh | 2 +- .../common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu | 2 +- .../common/normalization/layernorm/ln_fwd_cuda_kernel.cu | 2 +- .../common/normalization/layernorm/ln_fwd_kernels.cuh | 2 +- transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp | 2 +- .../common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh | 2 +- .../normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 2 +- .../common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 2 +- .../common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh | 2 +- transformer_engine/common/nvtx.h | 2 +- transformer_engine/common/permutation/permutation.cu | 2 +- transformer_engine/common/recipe/__init__.py | 2 +- transformer_engine/common/recipe/delayed_scaling.cu | 2 +- transformer_engine/common/transformer_engine.cpp | 2 +- transformer_engine/common/transpose/cast_transpose.cu | 2 +- transformer_engine/common/transpose/cast_transpose_fusion.cu | 2 +- transformer_engine/common/transpose/multi_cast_transpose.cu | 2 +- transformer_engine/common/transpose/rtc/cast_transpose.cu | 2 +- .../common/transpose/rtc/cast_transpose_fusion.cu | 2 +- transformer_engine/common/transpose/rtc/transpose.cu | 2 +- 
transformer_engine/common/transpose/transpose.cu | 2 +- transformer_engine/common/transpose/transpose_fusion.cu | 2 +- transformer_engine/common/util/cast.cu | 2 +- transformer_engine/common/util/cuda_driver.cpp | 2 +- transformer_engine/common/util/cuda_driver.h | 2 +- transformer_engine/common/util/cuda_runtime.cpp | 2 +- transformer_engine/common/util/cuda_runtime.h | 2 +- transformer_engine/common/util/logging.h | 2 +- transformer_engine/common/util/math.h | 2 +- transformer_engine/common/util/padding.cu | 2 +- transformer_engine/common/util/pybind_helper.h | 2 +- transformer_engine/common/util/rtc.cpp | 2 +- transformer_engine/common/util/rtc.h | 2 +- transformer_engine/common/util/string.h | 2 +- transformer_engine/common/util/string_header.h.in | 2 +- transformer_engine/common/util/system.cpp | 2 +- transformer_engine/common/util/system.h | 2 +- transformer_engine/common/util/vectorized_pointwise.h | 2 +- transformer_engine/common/utils.cuh | 2 +- transformer_engine/common/utils.py | 2 +- transformer_engine/jax/__init__.py | 2 +- transformer_engine/jax/attention.py | 2 +- transformer_engine/jax/cpp_extensions/__init__.py | 2 +- transformer_engine/jax/cpp_extensions/activation.py | 2 +- transformer_engine/jax/cpp_extensions/attention.py | 2 +- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/cpp_extensions/custom_call.py | 2 +- transformer_engine/jax/cpp_extensions/misc.py | 2 +- transformer_engine/jax/cpp_extensions/normalization.py | 2 +- transformer_engine/jax/cpp_extensions/quantization.py | 2 +- transformer_engine/jax/cpp_extensions/softmax.py | 2 +- transformer_engine/jax/cpp_extensions/transpose.py | 2 +- transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/activation.cpp | 2 +- transformer_engine/jax/csrc/extensions/attention.cpp | 2 +- transformer_engine/jax/csrc/extensions/cudnn.cpp | 2 +- transformer_engine/jax/csrc/extensions/ffi.cpp | 2 +- transformer_engine/jax/csrc/extensions/ffi.h | 
2 +- transformer_engine/jax/csrc/extensions/misc.cpp | 2 +- transformer_engine/jax/csrc/extensions/misc.h | 2 +- transformer_engine/jax/csrc/extensions/normalization.cpp | 2 +- transformer_engine/jax/csrc/extensions/packing.cpp | 2 +- transformer_engine/jax/csrc/extensions/pybind.cpp | 2 +- transformer_engine/jax/csrc/extensions/quantization.cpp | 2 +- transformer_engine/jax/csrc/extensions/softmax.cpp | 2 +- transformer_engine/jax/csrc/extensions/transpose.cpp | 2 +- transformer_engine/jax/csrc/utils.cu | 2 +- transformer_engine/jax/csrc/utils.h | 2 +- transformer_engine/jax/dot.py | 2 +- transformer_engine/jax/flax/__init__.py | 2 +- transformer_engine/jax/flax/module.py | 2 +- transformer_engine/jax/flax/transformer.py | 2 +- transformer_engine/jax/fp8.py | 2 +- transformer_engine/jax/layernorm.py | 2 +- transformer_engine/jax/layernorm_mlp.py | 2 +- transformer_engine/jax/praxis/__init__.py | 2 +- transformer_engine/jax/praxis/module.py | 2 +- transformer_engine/jax/praxis/transformer.py | 2 +- transformer_engine/jax/setup.py | 2 +- transformer_engine/jax/sharding.py | 2 +- transformer_engine/jax/softmax.py | 2 +- transformer_engine/paddle/__init__.py | 2 +- transformer_engine/paddle/constants.py | 2 +- transformer_engine/paddle/cpp_extensions.py | 2 +- transformer_engine/paddle/csrc/common.cpp | 2 +- transformer_engine/paddle/csrc/common.h | 2 +- transformer_engine/paddle/csrc/custom_ops.cu | 2 +- transformer_engine/paddle/csrc/extensions.cpp | 2 +- transformer_engine/paddle/distributed.py | 2 +- transformer_engine/paddle/fp8.py | 2 +- transformer_engine/paddle/fp8_buffer.py | 2 +- transformer_engine/paddle/layer/__init__.py | 2 +- transformer_engine/paddle/layer/attention.py | 2 +- transformer_engine/paddle/layer/base.py | 2 +- transformer_engine/paddle/layer/layernorm.py | 2 +- transformer_engine/paddle/layer/layernorm_linear.py | 2 +- transformer_engine/paddle/layer/layernorm_mlp.py | 2 +- transformer_engine/paddle/layer/linear.py | 2 +- 
transformer_engine/paddle/layer/rmsnorm.py | 2 +- transformer_engine/paddle/layer/softmax.py | 2 +- transformer_engine/paddle/layer/transformer.py | 2 +- transformer_engine/paddle/profile.py | 2 +- transformer_engine/paddle/recompute.py | 2 +- transformer_engine/paddle/setup.py | 2 +- transformer_engine/paddle/utils.py | 2 +- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/attention.py | 2 +- transformer_engine/pytorch/constants.py | 2 +- transformer_engine/pytorch/cpp_extensions/__init__.py | 2 +- transformer_engine/pytorch/cpp_extensions/_common.py | 2 +- transformer_engine/pytorch/cpp_extensions/activation.py | 2 +- transformer_engine/pytorch/cpp_extensions/cast.py | 2 +- transformer_engine/pytorch/cpp_extensions/fused_attn.py | 2 +- transformer_engine/pytorch/cpp_extensions/gemm.py | 2 +- transformer_engine/pytorch/cpp_extensions/normalization.py | 2 +- transformer_engine/pytorch/cpp_extensions/padding.py | 2 +- transformer_engine/pytorch/cpp_extensions/transpose.py | 2 +- transformer_engine/pytorch/cpu_offload.py | 2 +- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- transformer_engine/pytorch/csrc/extensions/activation.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/apply_rope.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/attention.cu | 2 +- transformer_engine/pytorch/csrc/extensions/cast.cpp | 2 +- .../pytorch/csrc/extensions/comm_gemm_overlap.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/gemm.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/misc.cpp | 2 +- .../pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu | 2 +- .../csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu | 2 +- .../csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu | 2 +- .../csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu | 2 +- transformer_engine/pytorch/csrc/extensions/normalization.cpp | 2 +- 
transformer_engine/pytorch/csrc/extensions/padding.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/permutation.cu | 2 +- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/recipe.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/softmax.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/transpose.cpp | 2 +- transformer_engine/pytorch/csrc/multi_tensor_apply.cuh | 2 +- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 2 +- transformer_engine/pytorch/csrc/type_shim.h | 2 +- transformer_engine/pytorch/distributed.py | 2 +- transformer_engine/pytorch/export.py | 2 +- transformer_engine/pytorch/float8_tensor.py | 2 +- transformer_engine/pytorch/fp8.py | 2 +- transformer_engine/pytorch/graph.py | 2 +- transformer_engine/pytorch/jit.py | 2 +- transformer_engine/pytorch/module/__init__.py | 2 +- transformer_engine/pytorch/module/_common.py | 2 +- transformer_engine/pytorch/module/base.py | 2 +- transformer_engine/pytorch/module/fp8_padding.py | 2 +- transformer_engine/pytorch/module/fp8_unpadding.py | 2 +- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/layernorm.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/module/rmsnorm.py | 2 +- transformer_engine/pytorch/numerics_debug.py | 2 +- transformer_engine/pytorch/ops/__init__.py | 2 +- transformer_engine/pytorch/ops/_common.py | 2 +- transformer_engine/pytorch/ops/basic/__init__.py | 2 +- transformer_engine/pytorch/ops/basic/activation.py | 2 +- transformer_engine/pytorch/ops/basic/add_in_place.py | 2 +- transformer_engine/pytorch/ops/basic/all_gather.py | 2 +- transformer_engine/pytorch/ops/basic/all_reduce.py | 2 +- transformer_engine/pytorch/ops/basic/basic_linear.py | 2 +- transformer_engine/pytorch/ops/basic/bias.py | 2 +- 
transformer_engine/pytorch/ops/basic/identity.py | 2 +- transformer_engine/pytorch/ops/basic/layer_norm.py | 2 +- transformer_engine/pytorch/ops/basic/make_extra_output.py | 2 +- transformer_engine/pytorch/ops/basic/quantize.py | 2 +- transformer_engine/pytorch/ops/basic/reduce_scatter.py | 2 +- transformer_engine/pytorch/ops/basic/reshape.py | 2 +- transformer_engine/pytorch/ops/basic/rmsnorm.py | 2 +- transformer_engine/pytorch/ops/fused/__init__.py | 2 +- transformer_engine/pytorch/ops/fused/backward_linear_add.py | 2 +- .../pytorch/ops/fused/forward_linear_bias_activation.py | 2 +- transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py | 2 +- .../pytorch/ops/fused/userbuffers_backward_linear.py | 2 +- .../pytorch/ops/fused/userbuffers_forward_linear.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 2 +- transformer_engine/pytorch/ops/linear.py | 2 +- transformer_engine/pytorch/ops/op.py | 2 +- transformer_engine/pytorch/ops/sequential.py | 2 +- transformer_engine/pytorch/optimizers/__init__.py | 2 +- transformer_engine/pytorch/optimizers/fused_adam.py | 2 +- transformer_engine/pytorch/optimizers/fused_sgd.py | 2 +- transformer_engine/pytorch/optimizers/multi_tensor_apply.py | 2 +- transformer_engine/pytorch/permutation.py | 2 +- transformer_engine/pytorch/setup.py | 2 +- transformer_engine/pytorch/softmax.py | 2 +- transformer_engine/pytorch/te_onnx_extensions.py | 2 +- transformer_engine/pytorch/tensor/__init__.py | 2 +- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/tensor/quantized_tensor.py | 2 +- transformer_engine/pytorch/transformer.py | 2 +- transformer_engine/pytorch/utils.py | 2 +- 421 files changed, 421 insertions(+), 421 deletions(-) diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 260adfc6d3..1402cc091a 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b5b262baff..964e71fa8c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index fc5e27d0a4..6470eee838 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index b6fadba1bd..3c4229a888 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/license.yml b/.github/workflows/license.yml index f789a83d1a..d70c7def61 100644 --- a/.github/workflows/license.yml +++ b/.github/workflows/license.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d2bd865a8f..f98fc9aa3a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 86d22b7944..cef039f976 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/upload-ci-logs.yml b/.github/workflows/upload-ci-logs.yml index b3be2f5c89..c9c7e4ef4d 100644 --- a/.github/workflows/upload-ci-logs.yml +++ b/.github/workflows/upload-ci-logs.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 95767b742f..d92fd95675 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/CPPLINT.cfg b/CPPLINT.cfg index e42ec720b1..ecfbbf3d0b 100644 --- a/CPPLINT.cfg +++ b/CPPLINT.cfg @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/README.rst b/README.rst index 6cc7eeae8a..3f4d9bd4a3 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index cff7c65fbc..dafafdff47 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/__init__.py b/build_tools/__init__.py index 9bcbd954eb..7669e4cfa6 100644 --- a/build_tools/__init__.py +++ b/build_tools/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index af11ada34c..5744439c1b 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/jax.py b/build_tools/jax.py index f829230f50..7e0652c629 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/build_tools/paddle.py b/build_tools/paddle.py index a68d73956e..f0fcdb8f25 100644 --- a/build_tools/paddle.py +++ b/build_tools/paddle.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 575b7bee79..f060e99dff 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/te_version.py b/build_tools/te_version.py index b40fb26014..0aee63f647 100644 --- a/build_tools/te_version.py +++ b/build_tools/te_version.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/utils.py b/build_tools/utils.py index d846b87f22..f2a4200685 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index 7d839958cb..223c4a7f1c 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 7dedf2a761..26122eed9b 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 7682a2b6aa..ceebe626f4 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/launch_aarch.sh b/build_tools/wheel_utils/launch_aarch.sh index 9a8d796119..04e3cd6916 100644 --- a/build_tools/wheel_utils/launch_aarch.sh +++ b/build_tools/wheel_utils/launch_aarch.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/launch_x86.sh b/build_tools/wheel_utils/launch_x86.sh index 7b5649a642..b0d20be3f4 100644 --- a/build_tools/wheel_utils/launch_x86.sh +++ b/build_tools/wheel_utils/launch_x86.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/api/c/activation.rst b/docs/api/c/activation.rst index 1790121236..5b50aa513d 100644 --- a/docs/api/c/activation.rst +++ b/docs/api/c/activation.rst @@ -1,5 +1,5 @@ .. 
- Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/cast.rst b/docs/api/c/cast.rst index ef98441812..2ae05a8456 100644 --- a/docs/api/c/cast.rst +++ b/docs/api/c/cast.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/fused_attn.rst b/docs/api/c/fused_attn.rst index a0b6255ebe..6db67f26fe 100644 --- a/docs/api/c/fused_attn.rst +++ b/docs/api/c/fused_attn.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/gemm.rst b/docs/api/c/gemm.rst index e7a14cab97..711733fc4c 100644 --- a/docs/api/c/gemm.rst +++ b/docs/api/c/gemm.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index ae0b6ddfa1..d33e5ab607 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/layer_norm.rst b/docs/api/c/layer_norm.rst index 47c0585a42..3ac1c6842d 100644 --- a/docs/api/c/layer_norm.rst +++ b/docs/api/c/layer_norm.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/rmsnorm.rst b/docs/api/c/rmsnorm.rst index fba3b97c57..d6f378cebc 100644 --- a/docs/api/c/rmsnorm.rst +++ b/docs/api/c/rmsnorm.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/softmax.rst b/docs/api/c/softmax.rst index 69875d603c..55dc5d47de 100644 --- a/docs/api/c/softmax.rst +++ b/docs/api/c/softmax.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/transformer_engine.rst b/docs/api/c/transformer_engine.rst index ec474592c3..b5fd95e005 100644 --- a/docs/api/c/transformer_engine.rst +++ b/docs/api/c/transformer_engine.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/transpose.rst b/docs/api/c/transpose.rst index d839f3d3b1..9a3ba9e48b 100644 --- a/docs/api/c/transpose.rst +++ b/docs/api/c/transpose.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/common.rst b/docs/api/common.rst index 40afd88ff3..85201aee5d 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. 
diff --git a/docs/api/framework.rst b/docs/api/framework.rst index 88785f941a..acd54fe3b1 100644 --- a/docs/api/framework.rst +++ b/docs/api/framework.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/jax.rst b/docs/api/jax.rst index c7701bd699..d72af37ec5 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst index ad23031f58..3b3ecf55c6 100644 --- a/docs/api/paddle.rst +++ b/docs/api/paddle.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index ba4e7db352..986d79808c 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/conf.py b/docs/conf.py index 7d2d4ea7b9..4083bfd242 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index 85ce01079c..e9eec14d99 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 15022005a1..2c32e8b5f7 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index 0582efd52e..ead95d3bad 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 4413bdfd00..5a40a62da7 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 1aebe13afb..66f05701f5 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/faq.rst b/docs/faq.rst index 50b3a7481e..2f9cbd2720 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/index.rst b/docs/index.rst index 38e095c239..cd9ce41cf5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/installation.rst b/docs/installation.rst index 9ac0ddf841..fae01c64fa 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index dcbfafc467..c79fa45239 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Shared functions for the encoder tests""" diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index bafd9bd2fb..918dfd8238 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training on multi-GPU with tesnor parallelism""" diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index a4a19b43c2..c0325d3e28 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training on multi-GPU with data parallelism""" diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index f54deff69c..ff6fd4d167 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index ac71fe4c0e..b2439278ea 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training on single GPU""" diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index b251bb72ca..54ecadeee8 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """MNIST training on single GPU""" diff --git a/examples/paddle/mnist/test_single_gpu_mnist.py b/examples/paddle/mnist/test_single_gpu_mnist.py index de5c9e9b6c..15e81646ec 100644 --- a/examples/paddle/mnist/test_single_gpu_mnist.py +++ b/examples/paddle/mnist/test_single_gpu_mnist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""MNIST example of Transformer Engine Paddle""" diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index ab6b656be9..d94c352401 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/examples/pytorch/fsdp/README.md b/examples/pytorch/fsdp/README.md index 5ea1225fa1..d62f68bbda 100644 --- a/examples/pytorch/fsdp/README.md +++ b/examples/pytorch/fsdp/README.md @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index cf0a75c336..622228536c 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index 2a003f0a0d..ff9e2f0785 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index d68a3a0f41..b6a6d2d6e4 100644 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index f9e16793a4..f1d1c06d38 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh index 7bc84eef51..72a5c3828e 100644 --- a/qa/L0_jax_lint/test.sh +++ b/qa/L0_jax_lint/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 278a3c8b44..6eff047721 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 2c3b832933..71c1ad5b23 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/qa/L0_license/copyright_checker.py b/qa/L0_license/copyright_checker.py index 46a3a6d4fe..bfd1973033 100644 --- a/qa/L0_license/copyright_checker.py +++ b/qa/L0_license/copyright_checker.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding: utf-8 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_license/test.sh b/qa/L0_license/test.sh index 8b9c86b39b..4342e22c23 100644 --- a/qa/L0_license/test.sh +++ b/qa/L0_license/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh index 5c5379554f..1c26bd265b 100644 --- a/qa/L0_paddle_lint/test.sh +++ b/qa/L0_paddle_lint/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_paddle_unittest/test.sh b/qa/L0_paddle_unittest/test.sh index 1038923b5a..9312f22ba4 100644 --- a/qa/L0_paddle_unittest/test.sh +++ b/qa/L0_paddle_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index 00653877b8..5116bdb5cf 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # See LICENSE for license information. diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index ac517976c7..13cf07cafc 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 61dd15d015..793fa47259 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index fd8457c44b..f650a30f01 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index eb09df1a84..deb0f93cec 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 4e52153db9..ee7c28ca5f 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh index b0aba17ef5..cfc6446909 100644 --- a/qa/L1_pytorch_mcore_integration/test.sh +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh index 5a01468064..8e4ef03b8e 100644 --- a/qa/L1_pytorch_onnx_test/test.sh +++ b/qa/L1_pytorch_onnx_test/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 6c23e39a48..e63ba358a5 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/qa/L3_pytorch_convergence_test/test.sh b/qa/L3_pytorch_convergence_test/test.sh index fca621f279..110e26cc8a 100644 --- a/qa/L3_pytorch_convergence_test/test.sh +++ b/qa/L3_pytorch_convergence_test/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/qa/format.sh b/qa/format.sh index d38b832263..caaa0ba416 100644 --- a/qa/format.sh +++ b/qa/format.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/setup.py b/setup.py index 3bb2fe6b95..16e988aa88 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 3bef457c43..d8c8d99fac 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index ab6b6a5316..178dc5e8dd 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index 7d03e41271..cec997d078 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 39a6614179..05fcafb0b1 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index 651508c871..72d890f8e9 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 38ac955bc9..d3ba31fa53 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu index b1881b2a96..03cec4e658 100644 --- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu +++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu index 640434674b..5401b03296 100644 --- a/tests/cpp/operator/test_causal_softmax.cu +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index e7fb183217..e9f420e5b1 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu index e9e42725fe..23c824e857 100644 --- a/tests/cpp/operator/test_multi_padding.cu +++ b/tests/cpp/operator/test_multi_padding.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index bd6ee96af8..58152864eb 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu index 565e3986e6..76f049360a 100644 --- a/tests/cpp/operator/test_qdq.cu +++ b/tests/cpp/operator/test_qdq.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/tests/cpp/operator/test_transpose.cu b/tests/cpp/operator/test_transpose.cu index 844f6801f1..0852ddf7c3 100644 --- a/tests/cpp/operator/test_transpose.cu +++ b/tests/cpp/operator/test_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index b90ea183cb..84cc11673b 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index a6181256d9..4598a7b021 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index d93be956b0..ffa05f0d66 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/util/test_nvrtc.cpp b/tests/cpp/util/test_nvrtc.cpp index 03982deb73..e885140ce1 100644 --- a/tests/cpp/util/test_nvrtc.cpp +++ b/tests/cpp/util/test_nvrtc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/util/test_string.cpp b/tests/cpp/util/test_string.cpp index 14c1cc11f3..a2e8bc1410 100644 --- a/tests/cpp/util/test_string.cpp +++ b/tests/cpp/util/test_string.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 5bb86c6081..920f9dc62e 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """conftest for tests/jax""" diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index bbd54ecce5..c2d7039a53 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import operator diff --git a/tests/jax/pytest.ini b/tests/jax/pytest.ini index 2cbbe2ac67..4b1f68aa77 100644 --- a/tests/jax/pytest.ini +++ b/tests/jax/pytest.ini @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20b16c2809..4e4be7569f 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 1538062975..ccbdf9407c 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index f0dd56feaa..cc59ecfb34 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 38f7ec0d49..87a5145c65 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import pytest diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0ed6b84fd5..8f48bc77dd 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_functions.py b/tests/jax/test_functions.py index d6da307fd3..48a2fb4f88 100644 --- a/tests/jax/test_functions.py +++ b/tests/jax/test_functions.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 10da7486cf..01fc2b3e21 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tests for fused attention""" diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 3a0add0a38..e906a37414 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 55c09b4562..e6ad8ce20c 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test transformer_engine.jax.flax.TransformerLayer""" diff --git a/tests/jax/test_misc.py b/tests/jax/test_misc.py index 67145daf63..6db492921d 100644 --- a/tests/jax/test_misc.py +++ b/tests/jax/test_misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 8ac8ecbe79..935eb290e4 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_sanity_import.py b/tests/jax/test_sanity_import.py index f47c2eb411..5e1bca2c9c 100644 --- a/tests/jax/test_sanity_import.py +++ b/tests/jax/test_sanity_import.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index 4581cdc39e..0d50b73451 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 49e32e503c..8cc8448979 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tests for the softmax primitives""" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 242bafa5e2..3ff879e68c 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Utility for the TE layer tests""" diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py index 8c417b1930..f262f1a1d4 100644 --- a/tests/paddle/dist_launcher.py +++ b/tests/paddle/dist_launcher.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Helper functions to launch distributed tests""" diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py index c4605f121e..3e0a6d2bac 100644 --- a/tests/paddle/parallel_tests/amax_reduction.py +++ b/tests/paddle/parallel_tests/amax_reduction.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Unittest for Linear layer in tensor parallel""" diff --git a/tests/paddle/parallel_tests/attention_tp.py b/tests/paddle/parallel_tests/attention_tp.py index e145f20b39..c0ffa288ee 100644 --- a/tests/paddle/parallel_tests/attention_tp.py +++ b/tests/paddle/parallel_tests/attention_tp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Unittest for Transformer layer in tensor parallel""" diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py index 11060be38e..21d08a8ef3 100644 --- a/tests/paddle/parallel_tests/group_sharding.py +++ b/tests/paddle/parallel_tests/group_sharding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Unittest for group sharding""" diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py index 02295a71da..96070a03c5 100644 --- a/tests/paddle/parallel_tests/layernorm_linear_tp.py +++ b/tests/paddle/parallel_tests/layernorm_linear_tp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Unittest for LayerNormLinear layer in tensor parallel""" diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py b/tests/paddle/parallel_tests/layernorm_mlp_tp.py index f23cfb9e3f..9ec09c7e7a 100644 --- a/tests/paddle/parallel_tests/layernorm_mlp_tp.py +++ b/tests/paddle/parallel_tests/layernorm_mlp_tp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Unittest for LayerNormMLP layer in tensor parallel""" diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py index 0e7e90611e..68271e52e7 100644 --- a/tests/paddle/parallel_tests/linear_pp.py +++ b/tests/paddle/parallel_tests/linear_pp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Unittest for Linear layer in pipeline parallel""" diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py index 4a49474a37..1a42d6c621 100644 --- a/tests/paddle/parallel_tests/linear_tp.py +++ b/tests/paddle/parallel_tests/linear_tp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Unittest for Linear layer in tensor parallel""" diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py index 5506be042f..5fc3e7ddf3 100644 --- a/tests/paddle/parallel_tests/transformer_tp.py +++ b/tests/paddle/parallel_tests/transformer_tp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. """Unittest for Transformer layer in tensor parallel""" diff --git a/tests/paddle/recompute_tests/recompute_transformer_encoder.py b/tests/paddle/recompute_tests/recompute_transformer_encoder.py index 56d0c24535..e753f750c5 100644 --- a/tests/paddle/recompute_tests/recompute_transformer_encoder.py +++ b/tests/paddle/recompute_tests/recompute_transformer_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test TransformerLayer encoder recompute""" diff --git a/tests/paddle/test_install.py b/tests/paddle/test_install.py index 686771ec09..1c317584ed 100644 --- a/tests/paddle/test_install.py +++ b/tests/paddle/test_install.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test basic installation of Paddle extensions""" diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index b519fc0a0f..fbd6c61ad7 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test TE Paddle Layer-level APIs""" diff --git a/tests/paddle/test_master_grad.py b/tests/paddle/test_master_grad.py index 4e029cf8dd..c896a7871c 100644 --- a/tests/paddle/test_master_grad.py +++ b/tests/paddle/test_master_grad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Test TransformerLayer encoder main_grad""" diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index b3b8560775..d9b1fa5cd1 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test TE operators""" diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py index f07d56d44b..82f970b2c8 100644 --- a/tests/paddle/test_parallel.py +++ b/tests/paddle/test_parallel.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test TE Paddle Parallel""" diff --git a/tests/paddle/test_recompute.py b/tests/paddle/test_recompute.py index 02dddad210..59079b0d1d 100644 --- a/tests/paddle/test_recompute.py +++ b/tests/paddle/test_recompute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test TE Paddle Recompute""" diff --git a/tests/paddle/test_sanity_import.py b/tests/paddle/test_sanity_import.py index 9b38d543da..0390f2f6a0 100644 --- a/tests/paddle/test_sanity_import.py +++ b/tests/paddle/test_sanity_import.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py index 572af66ff9..b0a8d0d80b 100644 --- a/tests/paddle/utils.py +++ b/tests/paddle/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Utils for testing""" diff --git a/tests/pytorch/custom_ort_ops/CMakeLists.txt b/tests/pytorch/custom_ort_ops/CMakeLists.txt index 90fb3624c1..d3e95bd4bc 100644 --- a/tests/pytorch/custom_ort_ops/CMakeLists.txt +++ b/tests/pytorch/custom_ort_ops/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/custom_ort_ops/build.sh b/tests/pytorch/custom_ort_ops/build.sh index 989da2f4ef..01502ba6fb 100644 --- a/tests/pytorch/custom_ort_ops/build.sh +++ b/tests/pytorch/custom_ort_ops/build.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.cc b/tests/pytorch/custom_ort_ops/custom_op_library.cc index f46e897152..c7b94ff700 100755 --- a/tests/pytorch/custom_ort_ops/custom_op_library.cc +++ b/tests/pytorch/custom_ort_ops/custom_op_library.cc @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.h b/tests/pytorch/custom_ort_ops/custom_op_library.h index 7e4b8256bc..747e6c5083 100755 --- a/tests/pytorch/custom_ort_ops/custom_op_library.h +++ b/tests/pytorch/custom_ort_ops/custom_op_library.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/pytorch/distributed/print_logs.py b/tests/pytorch/distributed/print_logs.py index 6c25db4945..9d3cb3838f 100644 --- a/tests/pytorch/distributed/print_logs.py +++ b/tests/pytorch/distributed/print_logs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 0f00a6717b..e32f64cf1c 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index b00b8cc042..4f170e3f84 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e32a7ccb12..e49174c24f 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/run_megatron_lm_gpt.sh b/tests/pytorch/distributed/run_megatron_lm_gpt.sh index 855f0c3030..356399662c 100755 --- a/tests/pytorch/distributed/run_megatron_lm_gpt.sh +++ b/tests/pytorch/distributed/run_megatron_lm_gpt.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 5d2828454c..64f36051c6 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index f81fbae1fe..240e396534 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import os diff --git a/tests/pytorch/distributed/test_convergence.py b/tests/pytorch/distributed/test_convergence.py index 5a267cb25e..2d468cd301 100644 --- a/tests/pytorch/distributed/test_convergence.py +++ b/tests/pytorch/distributed/test_convergence.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index d8a018761b..598859b826 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index ead121f314..b61f519c99 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index d0b445a505..1a6191f06c 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 3c9197c322..02a85f0ac4 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 3ddfab055c..7a4d953840 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 588e6e4ecd..d546118ffb 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index fd8e543adc..73994e1873 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 010050baea..d92884eaa2 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py index 0469a01c5f..7d6d523622 100644 --- a/tests/pytorch/test_deferred_init.py +++ b/tests/pytorch/test_deferred_init.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index a25ffa773c..96b4ab4967 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 4d4eb38342..be01f2c011 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 81c4973756..7ad2c93aa5 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import math diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index fd2832c1d4..e2f712cce8 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_gqa.py b/tests/pytorch/test_gqa.py index 9f9098891f..3ef4806182 100644 --- a/tests/pytorch/test_gqa.py +++ b/tests/pytorch/test_gqa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_jit.py b/tests/pytorch/test_jit.py index 7d69e03712..ec62fba9d9 100644 --- a/tests/pytorch/test_jit.py +++ b/tests/pytorch/test_jit.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 216b200e09..ecc06c3ace 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c237dbaeb6..e9b6303933 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 6a463b556a..46e888462a 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index ed25b96955..2fd8e49114 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 0c2118718c..646dea552e 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 32d517460a..daf8506593 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_sanity_import.py b/tests/pytorch/test_sanity_import.py index 954d807b7d..5657cf0d85 100644 --- a/tests/pytorch/test_sanity_import.py +++ b/tests/pytorch/test_sanity_import.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index be77109cb7..46ce33becc 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index a8b181a187..450c24da33 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index b18ded9775..d97d9653e6 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 84fc567cd3..3efe116105 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 4bcd1f8e27..efcd4dc0b0 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 6184e235bd..ddb786bd3a 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index f9cd7b845a..cb38b351e9 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index c18d018a8e..7653991819 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index c745ffeeb4..5a0e0ead84 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index c6f0f870ff..003ea9588c 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc index 2fc6ffbdf9..71ea00de3a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h index 979df384a8..aa6021a190 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 6f3eef3d28..c3453aeffe 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 91667958e7..b2cd71f76b 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 75655ef691..ee808b7f9a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 4e95fc24de..01b940f06a 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 8830c8875d..d47ce472e5 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index 35e2d11799..80d2707315 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h index d2827b637a..eb19b9ddb2 100644 --- a/transformer_engine/common/cudnn_utils.h +++ b/transformer_engine/common/cudnn_utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 32e6d4df8f..5d3e1d6097 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index cade624c8d..20467af663 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 3a1216f891..687928d080 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 9341ebf5f9..08e0642b29 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index a5b25f3279..171fe846ce 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f8fe458219..0044a94b2f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 55830d3cda..3daf45d162 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/thd_utils.cu b/transformer_engine/common/fused_attn/thd_utils.cu index a1e353be71..17c732c530 100644 --- a/transformer_engine/common/fused_attn/thd_utils.cu +++ b/transformer_engine/common/fused_attn/thd_utils.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/common/fused_attn/thd_utils.h index c9a62727e6..91f5f7bac1 100644 --- a/transformer_engine/common/fused_attn/thd_utils.h +++ b/transformer_engine/common/fused_attn/thd_utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index a053c55fb6..8e2d831413 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f790d3b567..ed498049c3 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 26f104d3ed..7f35ddd70b 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index 841edcf043..4628b37949 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 08fd32af9c..2d31f82bab 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 8571887ee6..7d97680ec3 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 593ec086d7..ef7cdc0af9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 656c647fd4..53a66c25b5 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 32f16922b9..88a7dec251 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index 9043162bcb..ea3bdcd14e 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 1d5d192a39..8e0d017a0d 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/cudnn.h b/transformer_engine/common/include/transformer_engine/cudnn.h index c5e4bc23a9..70acead631 100644 --- a/transformer_engine/common/include/transformer_engine/cudnn.h +++ b/transformer_engine/common/include/transformer_engine/cudnn.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ae08f2a4aa..b9c8db1598 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index b7b9b93881..41a0e3bc76 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 1cdbfd2eb5..2cb99f3d28 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h index de9644792b..8c34540e34 100644 --- a/transformer_engine/common/include/transformer_engine/normalization.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h index a419b38234..4258463b1b 100644 --- a/transformer_engine/common/include/transformer_engine/padding.h +++ b/transformer_engine/common/include/transformer_engine/padding.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index c6263bf87e..195075c975 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 61b1f231b8..a076a4e89a 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/softmax.h b/transformer_engine/common/include/transformer_engine/softmax.h index 6a6fc15fa6..9f1c423172 100644 --- a/transformer_engine/common/include/transformer_engine/softmax.h +++ b/transformer_engine/common/include/transformer_engine/softmax.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index d302518235..99b3508362 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index ef3d344b05..781f171cd8 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 5b6beb66b1..89e2e9feec 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index d1d56d5cc9..f366ba26db 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h index 0f8fea3f0b..78d9212de6 100644 --- a/transformer_engine/common/normalization/kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 2c65131b9d..a412bae745 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index 44078a040b..b68e79cd98 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index d6e15dfc30..f63edfb644 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index e7fe7a201b..9336abc26c 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 3ec5543c3a..eb2f62b4b0 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index f6e36ae3c9..dd4c8e580d 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index 223ac7fd79..5d8a5b765a 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 309075c1ec..fb5741b35b 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 73634fc2dd..25bed95dc5 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index 5965ffdc5d..c631847395 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/nvtx.h b/transformer_engine/common/nvtx.h index 4625e0ab9d..ada7a59092 100644 --- a/transformer_engine/common/nvtx.h +++ b/transformer_engine/common/nvtx.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 2b894fbfdc..7e9e2a97f7 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ba276ad406..2c9944439d 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index fcace6ac3d..b16bad9e6a 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a3b49f9fa..11e0e319ed 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index dd45d0a668..b49c61195e 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index a8361d57ea..ed919c8b94 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 4026016519..16894ad4b5 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu index 07244a42e9..952d70f38b 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index 4ba1cb4c69..2424247bbe 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/rtc/transpose.cu b/transformer_engine/common/transpose/rtc/transpose.cu index 09758698f6..6d05c68106 100644 --- a/transformer_engine/common/transpose/rtc/transpose.cu +++ b/transformer_engine/common/transpose/rtc/transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 5e8ef80ae4..339748ead0 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index c032371940..39c702dade 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index dd03afd21b..e0c92c22cb 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 797a11c43c..8605447c61 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 9dc1114580..dcad582210 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 8d2e852988..cc9a659b5b 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index ea1ba84772..33c2aea8d4 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 7972db3162..10a4ec28dc 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 26204cddb8..2d425d6753 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 017d2e6a56..e90d2de558 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..97c5bee2b1 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index c03654bfc5..bc286dd621 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h index 2c79d038b2..820b16c206 100644 --- a/transformer_engine/common/util/rtc.h +++ b/transformer_engine/common/util/rtc.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/string.h b/transformer_engine/common/util/string.h index 3b0db02809..0064144102 100644 --- a/transformer_engine/common/util/string.h +++ b/transformer_engine/common/util/string.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/util/string_header.h.in b/transformer_engine/common/util/string_header.h.in index adbbb90d73..b9fa83a94f 100644 --- a/transformer_engine/common/util/string_header.h.in +++ b/transformer_engine/common/util/string_header.h.in @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/system.cpp b/transformer_engine/common/util/system.cpp index 0659061b47..502dced9fc 100644 --- a/transformer_engine/common/util/system.cpp +++ b/transformer_engine/common/util/system.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/system.h b/transformer_engine/common/util/system.h index 67626f7167..e3a7164932 100644 --- a/transformer_engine/common/util/system.h +++ b/transformer_engine/common/util/system.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 8653bf45a4..faf3ea0a61 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 6703ce728c..6267baf19e 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/utils.py b/transformer_engine/common/utils.py index 6fd9d141b4..a808e1571f 100644 --- a/transformer_engine/common/utils.py +++ b/transformer_engine/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""The utilities for Transformer Engine""" diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 05adbd624c..31f597c37f 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Transformer Engine bindings for JAX""" diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 53451b6a78..997d4657df 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX multi-head attention modules""" diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 579daa8e41..dfb68c113c 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Python interface for c++ extensions""" diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 7f09e6f900..4a29fce2c4 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""JAX/TE custom ops for activation""" diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index f3dfca21ef..b4bf1c6fd6 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for attention""" diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 3715e6f20c..1f148c86ab 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE base custom ops""" diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 1075030a0d..6739ac8bda 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom call""" diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 1f13484b98..3ec6502152 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. """JAX/TE miscellaneous for custom ops""" diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 8ad7ee4fcb..d7512b0e70 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for normalization""" diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 062bbbf0fb..c3ea8cb7aa 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for quantization""" diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 67053ecd8e..5c55dd3672 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""JAX/TE custom ops for softmax""" diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 2338572e30..d07b6944fb 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for transpose""" diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 64f3c467b6..6c3e2aa97d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index a2090bceba..41a6846a7c 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 4bde10fc46..dc857aa22c 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/cudnn.cpp b/transformer_engine/jax/csrc/extensions/cudnn.cpp index 95f505e226..19fe33b818 100644 --- a/transformer_engine/jax/csrc/extensions/cudnn.cpp +++ b/transformer_engine/jax/csrc/extensions/cudnn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index 8b627aad35..f991aeea18 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index d886064cae..ab1d34cf5a 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index 357a5679db..b1445e5bed 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 7f6179e91c..7ccfc85e8e 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 845eb844e2..95b33708f0 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index ccc6921f43..151a1d869a 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a986b91b30..9c92fe8b33 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index d08368657e..569dfd3baa 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index f54ebefcb0..1cf281e64b 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 8480081a68..516930c529 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/jax/csrc/utils.cu b/transformer_engine/jax/csrc/utils.cu index 8ca34013b3..2229c85165 100644 --- a/transformer_engine/jax/csrc/utils.cu +++ b/transformer_engine/jax/csrc/utils.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 32de33bac9..01d950e168 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/dot.py b/transformer_engine/jax/dot.py index 8981af8b7c..cb8722e089 100644 --- a/transformer_engine/jax/dot.py +++ b/transformer_engine/jax/dot.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX te modules""" diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 6655091caa..f386bdce22 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Transformer Engine bindings for JAX""" diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8b13c47cd4..7aa14fb1ba 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index cb71188221..e343e9d823 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 5df8ce4386..e7ee350b46 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 4f2e83d9a2..2f120443dd 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""JAX layernorm modules""" diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index bbf0b0f52b..c2d76c1fd3 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX MLP modules""" diff --git a/transformer_engine/jax/praxis/__init__.py b/transformer_engine/jax/praxis/__init__.py index 5be51a6d71..5352f1f53b 100644 --- a/transformer_engine/jax/praxis/__init__.py +++ b/transformer_engine/jax/praxis/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Praxis related Modules""" diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index e5649bfe7c..ce407f94fc 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index 2ae212afb9..f441834355 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
""" diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index c2219e3ba9..0f69939f36 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index f2da288be5..c24e550198 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index c63ee85e5d..9b32002388 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX softmax modules""" diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 50cf2186d6..583c4a7a7a 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py index 69d3859b8f..dee8a70c38 100644 --- a/transformer_engine/paddle/constants.py +++ b/transformer_engine/paddle/constants.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Constants""" diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index 281be66a8c..293c62a2fd 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """TE FP8 extensions and GEMMs""" diff --git a/transformer_engine/paddle/csrc/common.cpp b/transformer_engine/paddle/csrc/common.cpp index 5e35a28a6b..d65fbb2b50 100644 --- a/transformer_engine/paddle/csrc/common.cpp +++ b/transformer_engine/paddle/csrc/common.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 9b7e3d767a..83737c0d21 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index b35b4434db..460f4575e6 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/paddle/csrc/extensions.cpp b/transformer_engine/paddle/csrc/extensions.cpp index 128b7e2856..44ad2e7511 100644 --- a/transformer_engine/paddle/csrc/extensions.cpp +++ b/transformer_engine/paddle/csrc/extensions.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py index 75630ed28e..0e91341b80 100644 --- a/transformer_engine/paddle/distributed.py +++ b/transformer_engine/paddle/distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Methods needed for distributed training.""" diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index b9b315a150..7313a81975 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """FP8 utilities for TransformerEngine""" diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py index a880ca8107..06a9355e72 100644 --- a/transformer_engine/paddle/fp8_buffer.py +++ b/transformer_engine/paddle/fp8_buffer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """FP8 meta buffer for FP8 amax reduction""" diff --git a/transformer_engine/paddle/layer/__init__.py b/transformer_engine/paddle/layer/__init__.py index 58eb6a7c56..4d81ca231a 100644 --- a/transformer_engine/paddle/layer/__init__.py +++ b/transformer_engine/paddle/layer/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Layer level Paddle APIs""" diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 3ff5a42ff5..d3b0950dee 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Attntion API""" diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py index adbd1ce269..a854bb70db 100644 --- a/transformer_engine/paddle/layer/base.py +++ b/transformer_engine/paddle/layer/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Base modules and utilities for TransformerEngine Paddle API""" diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index 208e39ea03..be12b6534f 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Linear API""" diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index c39ad29957..57c91238e6 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """LayerNormLinear API""" diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 32f837183c..069fb82c69 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # See LICENSE for license information. """LayerNormMLP API""" diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py index af35955a1c..78b22ac7e4 100644 --- a/transformer_engine/paddle/layer/linear.py +++ b/transformer_engine/paddle/layer/linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Linear API""" diff --git a/transformer_engine/paddle/layer/rmsnorm.py b/transformer_engine/paddle/layer/rmsnorm.py index 1afc3d9759..23e406e3fb 100644 --- a/transformer_engine/paddle/layer/rmsnorm.py +++ b/transformer_engine/paddle/layer/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """RMSNorm API""" diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py index 11549364fe..971be68167 100644 --- a/transformer_engine/paddle/layer/softmax.py +++ b/transformer_engine/paddle/layer/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Fused scaled masked softmax functions""" diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index 4a9c2c38dc..feb79c0caa 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Transformer""" diff --git a/transformer_engine/paddle/profile.py b/transformer_engine/paddle/profile.py index 67d9afcb6f..d58679aea1 100644 --- a/transformer_engine/paddle/profile.py +++ b/transformer_engine/paddle/profile.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Utils for profiling""" diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py index 1d64ad0de0..5551583736 100644 --- a/transformer_engine/paddle/recompute.py +++ b/transformer_engine/paddle/recompute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Methods needed for recompute.""" diff --git a/transformer_engine/paddle/setup.py b/transformer_engine/paddle/setup.py index 5b1d1a1e04..c80f21a01d 100644 --- a/transformer_engine/paddle/setup.py +++ b/transformer_engine/paddle/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/paddle/utils.py b/transformer_engine/paddle/utils.py index 7b9aabbf5a..4a801495ab 100644 --- a/transformer_engine/paddle/utils.py +++ b/transformer_engine/paddle/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Utility functions for Transformer Engine modules""" diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 781f9d42fd..9b51d1369a 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9268b9636e..9f08f67304 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index bf5ca4d98e..c1790313ac 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 9f3c1b2424..be911fcd95 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py index 8f7e72e268..aec972994a 100644 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index f204982aa0..534e71d134 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index cd3c01c785..9c21edccec 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 1932e9feb2..332b4e52ee 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 932bb3cafa..c55f5a9fd4 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index 50fd6b7709..f997a8a536 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py index 41dfbe2466..cf704d06ee 100644 --- a/transformer_engine/pytorch/cpp_extensions/padding.py +++ b/transformer_engine/pytorch/cpp_extensions/padding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 188c03b27c..77bf0019af 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 123758b0da..2c8736ee09 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 2ac190863c..eb97dc36eb 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 82f58b1eda..94e1f7569a 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3b49ece4a3..3abcac5bf7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7f8cff5584..48832e6994 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index c0cd2e9920..d9977f01b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index d03a10ced3..50da91a1a1 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 47f5825866..771fa4920a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index d212d13516..6b54f2de69 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 40b96a057f..250c9993fb 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 8200942643..9785602998 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index 7d49a0848b..cb5e878fb2 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu index 3626bce9c2..8bc03ae375 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu index bd673b7d6e..d5d55c2872 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu index 3009a82768..5ea5c1d3d1 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 2574b84352..2124b551fd 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index d975ebeeef..ca10e4d3c9 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 0c9bed45e0..f363e6e7ea 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8856553c54..165855d430 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index a130169fe7..ec75a2a8c6 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index acb68543d8..93be90c9f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index f373cdf83a..40f76c898c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh index e85ec3afc2..f7598da45a 100644 --- a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh +++ b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 9f31dba669..203b575a0d 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/type_shim.h b/transformer_engine/pytorch/csrc/type_shim.h index 5d5a91f9eb..8100f0e4a2 100644 --- a/transformer_engine/pytorch/csrc/type_shim.h +++ b/transformer_engine/pytorch/csrc/type_shim.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 490ac3b160..e6d63ab9e4 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py index 5bc079711a..79b839edfd 100755 --- a/transformer_engine/pytorch/export.py +++ b/transformer_engine/pytorch/export.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index c3d8709925..8554cc7443 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15f20c81e5..b1b6165777 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f44500f7f2..3853e70048 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index ed08627e95..cda3939d6f 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index ba4755efe3..5074d32aa2 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 21365398f3..2be291e4f9 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d115efedaa..8de0b733a9 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 16d40cf401..1034398875 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index d45abe0668..b0832b0848 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 08c5addcfc..65023e493b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index b42079d299..1a635afbb8 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 92b37fcb07..189464cf80 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1a651474bf..1ce24e02d8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9492725f56..5fd4dd2fc9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index bd7db1f775..d2e0d1b2ba 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/numerics_debug.py b/transformer_engine/pytorch/numerics_debug.py index bc9a5f89e0..5a73f5b61b 100644 --- a/transformer_engine/pytorch/numerics_debug.py +++ b/transformer_engine/pytorch/numerics_debug.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index f65433398e..156c33210a 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index b1654add98..26bceab737 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index d6f4940c58..ae635c956a 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index a2e5a24a85..7ad6e70929 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_in_place.py index 041888f5d7..4ccbaef1c0 100644 --- a/transformer_engine/pytorch/ops/basic/add_in_place.py +++ b/transformer_engine/pytorch/ops/basic/add_in_place.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index b914d1dc6f..2dd1d1b75e 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index f466ade3a3..8b4593b934 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index ad86861114..c5178d2d91 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index eac1865566..5a73ec6c25 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py index 73179c68a6..d0466be15e 100644 --- a/transformer_engine/pytorch/ops/basic/identity.py +++ b/transformer_engine/pytorch/ops/basic/identity.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 710f838581..65717d5fa5 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index db1651c184..73d08b5c7f 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 313b6e5583..e3755decd6 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index c78dbc2877..03a02786b4 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index c3b1816635..53524cdd83 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 84f05ce713..32ef242b90 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 08b9f06123..b9b5ec9508 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 123c560066..1ddd8d116c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 3afdc3a0c3..c746f21f2c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 3d994d80f0..fa7f07cb95 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 907cff1c81..dab4c8f681 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index a1b0ca6a9e..1f3635eb4b 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 8b2a04cff8..dc96c12523 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 68472f171a..8ed2702a72 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index c55e0f7c19..30367d2c5e 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 8d4fefb4c5..3240bd73d6 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index fc9bdc304a..c76f75743d 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 93f6191dfe..170c95442f 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index ee428d2417..53fa59821c 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index 191b57eab9..64ec0a28da 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 540bacbf84..90cb5cc021 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index c527ca83ef..d3b3f03e10 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index a632851a76..3950c071b6 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 9b4b2df145..54eb37ecab 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 16b7f8b623..aceaaf5d10 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 414e819f53..d356df58dc 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 92c95b56ca..550e113389 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ad5476450b..7c3da9a73f 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 947c642c2c..63b2f2cfb5 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. From 7c23b966c98b3e63b4d716db663f4d99dd9bbf82 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 6 Jan 2025 04:24:24 -0800 Subject: [PATCH 044/239] update license for test_paged_attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index ff4a7ea0f4..ce40473c51 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
From 2dbf2e1833645d11e46beaaabfac532149eb1221 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 6 Jan 2025 04:26:25 -0800 Subject: [PATCH 045/239] update kv_cache_manager license Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/kv_cache_manager_non_paged.py | 2 +- transformer_engine/pytorch/kv_cache_manager_paged.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 0a72968f67..0f9ce5da66 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index cf4cba5b71..4066538dd7 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
From d2f1549abc9422242489ec9069859771c66cff67 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:09:33 -0800 Subject: [PATCH 046/239] fix build issue from previous merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index cfeb9b86b9..0b0f6dfe1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -68,13 +68,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -476,13 +476,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); From b898cbe18fb1d3414544bf8d11de5f83cadfb5db Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 8 Jan 2025 08:31:00 +0800 Subject: [PATCH 047/239] [JAX] Add THD + SWA unit tests (#1390) * Fix SWA mask for THD and forcing seqlen_kv >= seqlen_q for SWA Signed-off-by: Reese Wang * Generalize sliding window mask Signed-off-by: Reese Wang * Fix pylint Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_fused_attn.py | 60 ++++++++-------- tests/jax/utils.py | 10 +-- transformer_engine/jax/attention.py | 79 +++++++++------------- transformer_engine/jax/flax/transformer.py | 17 +++-- 4 files changed, 80 insertions(+), 86 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 01fc2b3e21..5cbbec7b04 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -148,30 +148,30 @@ def make_mask( segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ + # segment masks inv_mask = make_attention_mask( segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) + + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + 
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + + # causal mask if attn_mask_type.is_causal(): - if segment_pos_q is None: - segment_pos_q = jnp.broadcast_to( - jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape - ) - if segment_pos_kv is None: - segment_pos_kv = jnp.broadcast_to( - jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape - ) inv_causal_mask = make_attention_mask( segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) inv_mask = combine_masks(inv_causal_mask, inv_mask) - if window_size is not None: - max_seqlen_q = inv_mask.shape[-2] - max_seqlen_kv = inv_mask.shape[-1] - inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) - inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - inv_mask = combine_masks(inv_mask, inv_swa_mask) - + # sliding window mask + inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -314,13 +314,6 @@ def _get_max_segments_per_sequence(self): return self.num_segments_per_seq + 1 def _check_configs(self): - # TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference - if ( - self.qkv_layout.is_thd() - and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK - and self.window_size is not None - ): - pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.") # TODO(rewang): probably adds this in is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -432,7 +425,12 @@ def gen_valid(bs, max_seqlen, pad_ratio): return tokens, jnp.logical_not(tokens) def generate_random_segment_ids( - batch_size, sequence_length, num_segments, 
seed, with_segment_pad=True + batch_size, + sequence_length, + num_segments, + seed, + with_segment_pad=True, + min_segment_len=None, ): rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad @@ -448,15 +446,20 @@ def generate_random_segment_ids( current_pos = 0 segment_id = 1 - for _ in range(num_segments): - segment_size = rng.integers(1, max_segment_size + 1) + for seg_id in range(num_segments): + # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed + # TODO(rewang): Remove this constrain after cuDNN supports + min_segment_size = 1 + if min_segment_len is not None: + min_segment_size = min_segment_len[i][seg_id] + segment_size = rng.integers(min_segment_size, max_segment_size + 1) if current_pos + segment_size > sequence_length: break segment_end = current_pos + segment_size segment_ids[i, current_pos:segment_end] = segment_id segment_pos[i, current_pos:segment_end] = np.arange(segment_size) if with_segment_pad: - num_valid = rng.integers(1, segment_size + 1) + num_valid = rng.integers(min_segment_size, segment_size + 1) segment_pad[i, current_pos + num_valid : segment_end] = 1 current_pos = segment_end segment_id += 1 @@ -473,18 +476,21 @@ def generate_random_segment_ids( self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) if self.qkv_layout == QKVLayout.T3HD: self.segment_ids_kv = self.segment_ids_q self.segment_pos_kv = self.segment_pos_q self.pad_kv = self.pad_q else: + # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + min_segment_len = None if self.window_size is None else self.seqlens_q self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024, + min_segment_len=min_segment_len, ) - self.seqlens_q, 
self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 3ff879e68c..9cb02bc555 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -919,14 +919,14 @@ def apply_swa_mask( """Apply the sliding window mask to a given mask""" _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type) assert _attn_mask_type is not None + batch = original_mask.shape[0] max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] - swa_mask = make_swa_mask( - max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype - ) + pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q)) + pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv)) + swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype) # In swa_mask and original_mask 0 is masked out - swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) - new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask) + new_mask = jnp.where(original_mask == 1, swa_mask, original_mask) return new_mask diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 997d4657df..7b6c605236 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -147,59 +147,44 @@ class CPStrategy(Enum): def make_swa_mask( - max_seqlen_q: int, - max_seqlen_kv: int, + segment_pos_q: jnp.ndarray, + segment_pos_kv: jnp.ndarray, window_size: Optional[Tuple[int, int]] = None, - attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK, dtype: jax.typing.DTypeLike = jnp.float32, ): """ - Generate sliding window mask. `True` or `1` means keep the element. 
- - For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type, - the sliding window diagonal is aligned to the bottom right corner, and for other - mask types, the top left corner. - - Parameters - ---------- - max_seqlen_q: int - Maximum sequence length for queries. - max_seqlen_kv: int - Maximum sequence length for keys and values. - window_size: Optional[Tuple[int, int]] = None - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Negative number in window size means infinity window. - `None` means no sliding window. - attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK - dtype: jax.typing.DTypeLike, default=jnp.float32 - The mask data type. - Returns - ---------- - swa_mask: jax.numpy.tensor - Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions - that will get attention, value 0 are the masked out positions. + Generate a sliding window mask (1 = attend, 0 = masked). + + Args: + segment_pos_q (jnp.ndarray): + Query positions within each segment. For example, a batch with segment_ids = + [[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos = + [[0, 1, 2, 0, 1, 2, 3, 4]]. + segment_pos_kv (jnp.ndarray): + Key/value positions within each segment. + window_size (Optional[Tuple[int, int]], optional): + Sliding window size for local attention, where query at position i attends to keys + in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an + infinite window; None means no sliding window. + Defaults to None. + dtype (jax.typing.DTypeLike, optional): + Mask data type. Defaults to jnp.float32. + + Returns: + jnp.ndarray: + The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv]. 
""" - swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) - if window_size is None: - return swa_mask - left_window, right_window = window_size - if attn_mask_type.is_bottom_right(): - if left_window < 0: - left_window = max_seqlen_kv - if right_window < 0: - right_window = max_seqlen_kv - bottom_right_shift = max_seqlen_kv - max_seqlen_q - swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift) - swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift) + if window_size is not None: + left_window, right_window = window_size else: - if left_window < 0: - left_window = max_seqlen_q - if right_window < 0: - right_window = max_seqlen_q - swa_mask = jnp.triu(swa_mask, k=-left_window) - swa_mask = jnp.tril(swa_mask, k=right_window) - return swa_mask + left_window = right_window = jnp.inf + left_window = jnp.inf if left_window < 0 else left_window + right_window = jnp.inf if right_window < 0 else right_window + pos_q = jnp.expand_dims(segment_pos_q, axis=-1) + pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2) + inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window) + inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3) + return inv_swa_mask.astype(dtype) def canonicalize_attn_mask_type(attn_mask_type: str): diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index e343e9d823..cf2b13d074 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -194,15 +194,18 @@ def __call__( if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias - def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array: + def apply_swa_mask(original_mask: Array) -> Array: """Apply the sliding window mask to a given mask""" + batch = original_mask.shape[0] max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] - swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, 
self.window_size, attn_mask_type) - # In swa_mask 0 is masked out, in original_mask 1 is masked out - swa_mask = 1 - swa_mask.astype(original_mask.dtype) - swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) - new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask) + # TODO(rewang): Support THD format pos + pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q)) + pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv)) + # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out + inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype) + swa_mask = 1 - inv_swa_mask + new_mask = jnp.where(original_mask == 0, swa_mask, original_mask) return new_mask def convert_to_softmax_type(attn_mask_type, mask): @@ -213,7 +216,7 @@ def convert_to_softmax_type(attn_mask_type, mask): if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None: mask = None if mask is not None: - mask = apply_swa_mask(attn_mask_type, mask) + mask = apply_swa_mask(mask) # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask From 61cf102006878530bf66a5dd4275016e148430aa Mon Sep 17 00:00:00 2001 From: Liyuan Liu Date: Tue, 7 Jan 2025 18:07:14 -0800 Subject: [PATCH 048/239] bug fix for using `return_layernorm_output=True` (#1382) the current implementation would release the output of ln, leading to an error if setting `return_layernorm_output=True`. 
Signed-off-by: Liyuan Liu Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1ce24e02d8..7bcbb1eb7d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -373,7 +373,7 @@ def forward( ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) - if not is_grad_enabled: + if not is_grad_enabled and not return_layernorm_output: clear_tensor_data(ln_out_total) if bias_gelu_nvfusion: From a4cb1d177637bae32d16d1e600902538230daf03 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Wed, 8 Jan 2025 10:09:06 -0600 Subject: [PATCH 049/239] [JAX] Correct fused attention output after each step of ring attention (#1393) Correct fused attention output after each step to reduce intermediate memory use. 
Signed-off-by: Michael Goldfarb --- tests/jax/test_distributed_fused_attn.py | 8 +-- .../jax/cpp_extensions/attention.py | 61 ++++++++++++------- 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index ccbdf9407c..5a41911691 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -401,7 +401,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout): raise ValueError(f"Unsupported {qkv_layout=}") return qkv_args - def impl_test_contex_parallel_attn( + def impl_test_context_parallel_attn( self, device_count, mesh_shape, @@ -583,7 +583,7 @@ def grad_func(func, *args, **kwargs): assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) - def test_contex_parallel_allgather_attn( + def test_context_parallel_allgather_attn( self, device_count, mesh_shape, @@ -596,7 +596,7 @@ def test_contex_parallel_allgather_attn( qkv_layout, load_balanced, ): - return self.impl_test_contex_parallel_attn( + return self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -623,7 +623,7 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, ): - return self.impl_test_contex_parallel_attn( + return self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index b4bf1c6fd6..3a116ffb63 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1549,12 +1549,19 @@ def permute_kv(self, kv, cp_perm): """Permutes kv around the ring as described by cp_perm.""" return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm) - def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step): - """Apply soft max correction after an attention step.""" - max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step) 
- min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step) - new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale)) - return new_softmax_aux + @staticmethod + def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux): + """ + Corrects the output and softmax_aux tensor after each iteration of ring attention. + + See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for + derivation of this equation. + """ + new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose( + 0, 2, 1, 3 + ) * (output - partial_output) + new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux) + return new_out, new_aux def adjust_seqlen(self, seqlen, max_seqlen, idx): """Adjust the sequence length per step.""" @@ -1615,10 +1622,7 @@ def ring_attn_fwd_impl( cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] - output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype) - softmax_aux_per_steps = jnp.zeros( - (cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32 - ) + output = jnp.zeros(q.shape).astype(jnp.float32) softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32) # RNG shape should be the shared shape. This is unused for ring attention as we do not @@ -1627,7 +1631,7 @@ def ring_attn_fwd_impl( rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): - kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry + kv, output, softmax_aux = carry # Send KV block to next step so we can overlap compute. 
kv_next = helper.permute_kv(kv, cp_perm) @@ -1718,25 +1722,38 @@ def jax_cond_wrap(): else: output_per_step, softmax_aux_per_step = no_mask_compute() - softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step) - output_per_steps = output_per_steps.at[idx].set(output_per_step) - softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step) + def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step): + # No correction done here but we cast outputs to float32 and perform reduction + # in full precision. + # pylint: disable=unused-argument + return output_per_step.astype(jnp.float32), softmax_aux_per_step - return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps) + def correction(output, softmax_aux, output_per_step, softmax_aux_per_step): + return helper.correct_output_and_softmax_aux( + output, softmax_aux, output_per_step, softmax_aux_per_step + ) - carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) + # first step there is no correction we get initial output and stats + output, softmax_aux = lax.cond( + (idx == 0), + skip_correction, + correction, + output, + softmax_aux, + output_per_step, + softmax_aux_per_step, + ) + + return (kv_next, output, softmax_aux) + + carry = (kv, output, softmax_aux) if helper.use_scanloop(): carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) else: for i in range(0, cp_size): carry = scan_kv_block(i, carry) - (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry + (kv, output, softmax_aux) = carry - output = jnp.zeros(q.shape).astype(jnp.float32) - for idx in range(cp_size): - output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp( - softmax_aux_per_steps[idx] - softmax_aux - ).transpose(0, 2, 1, 3) output = output.astype(q.dtype) return output, softmax_aux, rng_state From 560bccf8d5116c776a1ea3af732712f4cd7f9e4e Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Wed, 8 
Jan 2025 10:09:48 -0800 Subject: [PATCH 050/239] clean CP implementation for flash attention and cuDNN 9.6 (#1387) * make pad_between_seqs check do not consider padding at the end Signed-off-by: Xiaowei Ren * change CP THD test to make it consider 0-length sequence Signed-off-by: Xiaowei Ren * minor change to flash func name Signed-off-by: Xiaowei Ren * only use varlen func of flash attention while qkv_format is THD Signed-off-by: Xiaowei Ren * try to converge code of flash and fused attentions Signed-off-by: Xiaowei Ren * fix bwd compute with P2P Signed-off-by: Xiaowei Ren * remove redundant out_per_step view Signed-off-by: Xiaowei Ren * enable cudnn>9.6 and THD+GQA Signed-off-by: Xiaowei Ren * enable CP with FusedAttn+SWA+All_Gather Signed-off-by: Xiaowei Ren * enable CP with FusedAttn+SWA+All_Gather Signed-off-by: Xiaowei Ren * code cleaning for cu_seqlens Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix some pylint error Signed-off-by: Xiaowei Ren * minor import change for pylint Signed-off-by: Xiaowei Ren * more fix for pylint Signed-off-by: Xiaowei Ren * fix lse_seqlen in thd out correction Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../fused_attn/run_fused_attn_with_cp.py | 26 +- .../fused_attn/test_fused_attn_with_cp.py | 10 +- transformer_engine/pytorch/attention.py | 838 +++++++++--------- .../pytorch/csrc/extensions/attention.cu | 4 +- 4 files changed, 442 insertions(+), 436 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 7a4d953840..1fae9e99f2 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -163,12 +163,10 @@ def run_dpa_with_cp( torch.tensor([q_input_shape[0]], dtype=torch.int32), ] ).cuda() - 
if kernel_backend == "FlashAttention": - cu_seqlens_q = cu_seqlens_q_padded[:-1] - else: - cu_seqlens_q = torch.cat( - [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)] - ).cuda() + cu_seqlens_q = torch.clone(cu_seqlens_q_padded) + if kernel_backend == "FusedAttention": + cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() + cu_seqlens_q[-1] = cu_seqlens_q[-2] cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: @@ -204,10 +202,8 @@ def run_dpa_with_cp( core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] - ), + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) @@ -276,10 +272,8 @@ def run_dpa_with_cp( core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] - ), + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) @@ -311,7 +305,7 @@ def run_dpa_with_cp( dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) @@ -327,7 +321,7 @@ def run_dpa_with_cp( ).item() == 
0 ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True ) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 73994e1873..9866591e8d 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") - if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0): - pytest.skip("THD format is not supported for cuDNN 9.6+!") config = model_configs_fused_attn[model] - if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: - pytest.skip("THD format does not support QGA/MQA yet!") if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if qkv_format == "thd" and "a2a" in cp_comm_type: pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") - if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": - pytest.skip( - "Sliding window attention only can be supported with the implementation of QKVO A2A!" - ) if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" 
@@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("FP8 attention cannot work with bias yet!") if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("FP8 attention cannot work with sliding window yet!") + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9f08f67304..3f0267affb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -125,11 +125,13 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_5_7_plus = False _flash_attn_2_6_0_plus = False +flash_attn_cuda_bwd = None flash_attn_func = None flash_attn_varlen_func = None -flash_attn_varlen_fwd = None -flash_attn_varlen_bwd = None -flash_attn_cuda_bwd = None +_flash_attn_fwd = None +_flash_attn_bwd = None +_flash_attn_varlen_fwd = None +_flash_attn_varlen_bwd = None try: _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) @@ -141,14 +143,16 @@ def _get_supported_versions(version_min, version_max): ) else: if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd + from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd from flash_attn.flash_attn_interface import ( - _flash_attn_varlen_forward as 
flash_attn_varlen_fwd, + _flash_attn_varlen_forward as _flash_attn_varlen_fwd, ) from flash_attn.flash_attn_interface import ( - _flash_attn_varlen_backward as flash_attn_varlen_bwd, + _flash_attn_varlen_backward as _flash_attn_varlen_bwd, ) - from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd _flash_attn_is_installed = True _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") @@ -195,11 +199,13 @@ def _get_supported_versions(version_min, version_max): from flashattn_hopper.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 + from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_forward as flash_attn_varlen_fwd_v3, + _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, ) from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3, + _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, ) _flash_attn_3_is_installed = True @@ -602,12 +608,6 @@ def get_attention_backend( "Disabling FusedAttention as it does not support context parallelism with MLA" ) use_fused_attention = False - elif cudnn_version >= (9, 6, 0) and qkv_format == "thd": - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with THD for" - " cuDNN 9.6+" - ) - use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends @@ -1804,12 +1804,20 @@ def forward( else: qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) - pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal( + cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1] + ) + 
pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal( + cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1] + ) max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size - cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size - cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size + cu_seqlens_q_padded = ( + None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size + ) + cu_seqlens_kv_padded = ( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size + ) cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] @@ -1882,9 +1890,6 @@ def forward( elif qkv_format == "sbhd": # [s, b, np, hn] -> [2, s//2, b, np, hn] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] - total_tokens_kv = None if qkv_format != "thd" else k.shape[0] - # remove padded tokens at the end - k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]] if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -1907,17 +1912,27 @@ def forward( ) assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" - softmax_lse_in_packed_format = not use_fused_attention and ( - _flash_attn_2_6_0_plus or _use_flash_attn_3 - ) + softmax_lse_in_packed_format = False + if qkv_format == "thd": + if use_fused_attention: + softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) + else: + softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3 + flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) else: - flash_attn_fwd = 
flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_3_plus: @@ -1943,7 +1958,7 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if use_fused_attention and qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) else: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) @@ -1991,31 +2006,31 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + q_inputs[i % 2] = q if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, 
*k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -2060,18 +2075,27 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv, + ] fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=True, **fa_forward_kwargs, ) @@ -2084,7 +2108,7 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2095,25 +2119,26 @@ def forward( True, False, ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 
2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] + elif qkv_format == "thd": + q_inputs[i % 2] = q + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous() - elif qkv_format == "thd": - q_inputs[i % 2] = q - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() @@ -2156,28 +2181,29 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) + fa_forward_args_thd = [] if qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() - # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, 
hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv // 2, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_forward_kwargs["window_size"] = (-1, -1) fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv // 2, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) @@ -2190,7 +2216,7 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2201,28 +2227,29 @@ def forward( True, True, ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_inputs[i % 2] = q[:, 1, ...] 
+ # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_inputs[i % 2] = q[1] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...].contiguous() - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1].contiguous() - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + q_inputs[i % 2] = q_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -2271,28 +2298,29 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + fa_forward_args_thd = [] if qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) - else: - # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] - q_inputs[i % 2] = ( - q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - ) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q 
// 2, + max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_forward_kwargs["window_size"] = (-1, -1) fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q // 2, - max_seqlen_kv, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) @@ -2305,7 +2333,7 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2316,7 +2344,7 @@ def forward( True, True, ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if attn_bias is not None: @@ -2363,18 +2391,27 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, sq, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv, + ] fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format 
in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) @@ -2389,13 +2426,13 @@ def forward( flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] + # [b, np, sq, 1] -> [b, np, sq] or + # [t, np, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) - if qkv_format != "thd" and softmax_lse_in_packed_format: - # [np, t] -> [np, b, sq] - softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view( - q.shape[-2], q.shape[0], -1 - ) + if softmax_lse_in_packed_format: + softmax_lse_per_step[i - 1] = ( + softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() + ) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: @@ -2410,8 +2447,7 @@ def forward( out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": - # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format - # [np, b, sq] -> [np, b, 2, sq//2] lse in packed format + # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) @@ -2439,16 +2475,6 @@ def forward( softmax_lse = softmax_lse.to(torch.float) for i in range(cp_size): - out_ = None - if qkv_format == "bshd": - out_per_step[i] = out_per_step[i].view( - out.shape[0], -1, *out.shape[-2:] - ) # pylint: disable=used-before-assignment - out_ = out[:, 1, ...] 
- elif qkv_format == "sbhd": - out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) - out_ = out[1] - if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction( @@ -2471,6 +2497,7 @@ def forward( ) else: if qkv_format in ["bshd", "sbhd"]: + out_ = out.select(seq_dim, 1) flash_attn_fwd_out_correction( out_, out_per_step[i], @@ -2490,9 +2517,6 @@ def forward( softmax_lse_in_packed_format, ) - if qkv_format != "thd" and softmax_lse_in_packed_format: - # [np, b, sq] -> [np, t] - softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1) kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) @@ -2587,7 +2611,6 @@ def forward( ctx.cp_global_ranks = cp_global_ranks ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p - ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale @@ -2597,6 +2620,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 @@ -2646,14 +2670,10 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None - softmax_lse_in_packed_format = not ctx.use_fused_attention and ( - _flash_attn_2_6_0_plus or _use_flash_attn_3 - ) - if causal: - if ctx.qkv_format == "thd" or softmax_lse_in_packed_format: + if ctx.qkv_format == "thd": softmax_lse_ = tex.thd_read_second_half_lse( - softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format + softmax_lse, cu_seqlens_q_padded, ctx.softmax_lse_in_packed_format ) else: # [b, np, sq] -> [b, np, 2, sq//2] @@ -2661,13 +2681,20 @@ def backward(ctx, dout): *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() - 
if ctx.use_fused_attention: - # [b, np, sq//2] -> [b, np, sq//2, 1] - softmax_lse_.unsqueeze_(-1) + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() + # [b, np, sq//2] -> [b, np, sq//2, 1] or + # [t//2, np] -> [t//2, np, 1] + softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: - # [b, np, sq] -> [b, np, sq, 1] + if ctx.softmax_lse_in_packed_format: + softmax_lse = softmax_lse.transpose(0, 1).contiguous() + # [b, np, sq] -> [b, np, sq, 1] or + # [t, np] -> [t, np, 1] softmax_lse.unsqueeze_(-1) + dq = None dout_dtype = dout.dtype fused_attn_backend = None fused_attn_qkv_dtype = None @@ -2715,8 +2742,6 @@ def backward(ctx, dout): dout_scale_inv = dout._scale_inv dout = dout._data dq = torch.empty_like(q) - if ctx.qkv_format == "thd" and causal: - dq[cu_seqlens_q_padded[-1] :].fill_(0) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2760,10 +2785,16 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd_v3 + else: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None @@ -2808,32 +2839,29 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] - dk_, dv_ = None, None + q_, kv_, out_, dout_ = None, None, None, None + dq_, dk_, dv_ = None, None, None if ctx.fp8 and ctx.use_fused_attention: fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] fp8_meta_kwargs["amax_dqkv"] = 
amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + q_, kv_, out_, dout_ = q, kv, out, dout if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_ = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - out_ = out.view(-1, *out.shape[-3:]) - dout_ = dout.view(-1, *dout.shape[-3:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -2869,15 +2897,16 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) + fa_backward_args_thd = [] + if 
ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, 0) if not _use_flash_attn_3: @@ -2885,42 +2914,36 @@ def backward(ctx, dout): flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=True, **fa_backward_kwargs, ) elif i >= (cp_size - rank - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] + elif ctx.qkv_format == "thd": + q_, out_, dout_ = q, out, dout + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0, ...].contiguous() - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # 
[2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_ = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0].contiguous() - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - out_ = out.view(-1, *out.shape[-3:]) - dout_ = dout.view(-1, *dout.shape[-3:]) - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + kv_ = kv_.contiguous() if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -2958,19 +2981,16 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - if ctx.qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - else: - # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q, + ctx.max_seqlen_kv // 2, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, -1) if not _use_flash_attn_3: @@ -2978,44 +2998,37 @@ def backward(ctx, dout): flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 
1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) else: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_, out_, dout_ = q[1], out[1], dout[1] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_, out_, dout_ = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q, out, dout] + ] + kv_ = kv if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, 1, ...].contiguous() - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - out_ = out[:, 1, ...].contiguous() - dout_ = dout[:, 1, ...].contiguous() - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[1].contiguous() - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - out_ = out[1].contiguous() - dout_ = dout[1].contiguous() - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) - kv_ = kv + q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] if ctx.fp8: aux_ctx_tensors = [ softmax_lse_, @@ -3053,23 +3066,16 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - if ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) - else: - # 
[b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = [] if ctx.qkv_format == "thd": - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) - dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q // 2, + ctx.max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, -1) if not _use_flash_attn_3: @@ -3077,17 +3083,14 @@ def backward(ctx, dout): flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse_, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) @@ -3124,50 +3127,41 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - # [b, sq, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) - dkv_ = torch.empty_like(kv_) - # [b, sq, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, 
*dout.shape[-2:]) + dq_ = torch.empty_like(q) + dkv_ = torch.empty_like(kv) + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, -1) if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( - dout_, - q_, - kv_[0], - kv_[1], - out_, + dout, + q, + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + out, softmax_lse, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) if ctx.fp8: dq = dq_fp8[(rank + i + 1) % cp_size] - if i >= (cp_size - rank - 1) or not causal: - # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal - # [b*sq, np, hn] -> [b, sq, np, hn] if not causal + if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): + # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or + # [sq, b, np, hn] -> [2, sq//2, b, np, hn] dq_ = dq_.view(*dq.shape) - else: - if ctx.qkv_format == "bshd": - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_ = dq_.view(dq.shape[0], *dq.shape[2:]) - elif ctx.qkv_format == "sbhd": - # [b*sq//2, np, hn] -> [sq//2, b, np, hn] - dq_ = dq_.view(-1, *dq.shape[-3:]) if ctx.fp8: if i >= (cp_size - rank - 1) or not causal: @@ -3242,24 +3236,21 @@ def backward(ctx, dout): else: dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment if 
ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] - dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) - elif ctx.qkv_format == "sbhd": - # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn] - dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:]) - else: - # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal - # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal - dkv_ = dkv_.view(*dkv.shape) + dkv_ = _combine_tensors([dk_, dv_], -2) + elif ctx.qkv_format == "thd": + dkv_ = torch.cat( + (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 + ) # pylint: disable=used-before-assignment + if ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) + dkv_ = dkv_.movedim(-3, 0) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv_ = dkv_.view(*dkv.shape) if ctx.fp8: if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): @@ -3341,13 +3332,9 @@ def backward(ctx, dout): # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) - if ctx.qkv_format == "thd": - dkv_ = torch.empty( - 2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device - ) - dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv) - dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) - dkv = dkv_ + if ctx.qkv_format == "thd" and not ctx.use_fused_attention: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: dq, dkv = [ @@ 
-3494,9 +3481,15 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = _flash_attn_fwd_v3 else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_4_plus: @@ -3514,8 +3507,11 @@ def forward( max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) - cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_q = cu_seqlens_q // (2 * cp_size) + cu_seqlens_q_padded = ( + None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size) + ) # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) @@ -3570,9 +3566,10 @@ def forward( kv_seq_range_per_step[i][1], ) max_seqlen_kv_ = seq_end_idx - seq_start_idx - cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device - ) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] @@ -3599,15 +3596,19 @@ def forward( window_size=window_size_per_step[i], ) else: - q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + 
max_seqlen_q, + max_seqlen_kv_, + ] fa_outputs = flash_attn_fwd( q_, k_, v_, - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv_, + *fa_forward_args_thd, causal=causal, window_size=window_size_per_step[i], **fa_forward_kwargs, @@ -3620,9 +3621,9 @@ def forward( if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape)) + out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape)) + out[i - 1].copy_(out_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3711,10 +3712,16 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd_v3 + else: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None @@ -3764,11 +3771,17 @@ def backward(ctx, dout): deterministic=ctx.deterministic, ) else: - batch_size = k_.shape[0] - q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] ] + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + ctx.max_seqlen_q, + max_seqlen_kv, + ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] flash_attn_bwd( @@ -3781,21 +3794,11 @@ def backward(ctx, dout): dq_per_step[i], dk_per_step[i], dv_per_step[i], - cu_seqlens_q, - 
cu_seqlens_kv_per_step[i], - ctx.max_seqlen_q, - max_seqlen_kv, + *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, window_size=window_size_per_step[i], **fa_backward_kwargs, ) - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) - # [b*s_range, np, hn] -> [b, s_range, np, hn] - dk_per_step[i], dv_per_step[i] = [ - x.view(batch_size, -1, *x.shape[-2:]) - for x in [dk_per_step[i], dv_per_step[i]] - ] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3916,10 +3919,16 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = window_size else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_3_plus: @@ -4025,24 +4034,25 @@ def forward( **fp8_meta_kwargs, ) else: - # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] - q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] fa_outputs = flash_attn_fwd( q, k, v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, + *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, ) out, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not _use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] - # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) out = flash_attn_a2a_communicate( @@ 
-4214,11 +4224,17 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd_v3 + else: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p if _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = ctx.window_size @@ -4255,8 +4271,15 @@ def backward(ctx, dout): ) else: softmax_lse, rng_state = aux_ctx_tensors - out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( @@ -4269,14 +4292,10 @@ def backward(ctx, dout): dq, dk, dv, - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + *fa_backward_args_thd, causal=causal, **fa_backward_kwargs, ) - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) dq, dk, dv = flash_attn_a2a_communicate( @@ -4400,18 +4419,17 @@ def attn_forward_func_with_cp( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" ) - assert ( + assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" 
+ ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) - assert ( - not sliding_window_attn - or cp_comm_type == "a2a" - or (cp_comm_type == "all_gather" and not use_fused_attention) - ), "The context parallel running configs cannot support sliding window attetnion!" + assert not sliding_window_attn or cp_comm_type in [ + "a2a", + "all_gather", + ], "The context parallel running configs cannot support sliding window attetnion!" args = [ is_training, @@ -5419,8 +5437,8 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, + cu_seqlens_q if qkv_format == "thd" else None, + cu_seqlens_kv if qkv_format == "thd" else None, self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -7215,7 +7233,7 @@ def forward( and cu_seqlens_kv is not None ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" 
- if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None: + if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None): cu_seqlens_q_padded = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_kv @@ -8151,10 +8169,10 @@ def forward( pad_between_seqs = ( cu_seqlens_q_padded is not None - and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) + and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) ) or ( cu_seqlens_kv_padded is not None - and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) ) attention_params = AttentionParams( diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50da91a1a1..f947930d23 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1537,10 +1537,10 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ int batch, lse_seqlen; if (lse_packed) { batch = cu_seqlens.size(0) - 1; - lse_seqlen = total_tokens; + lse_seqlen = lse.size(1); NVTE_CHECK(lse.size(0) == num_heads); - NVTE_CHECK(lse.size(1) == lse_seqlen); + NVTE_CHECK(lse_seqlen >= total_tokens); NVTE_CHECK(lse_per_step.size(0) == num_heads); NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1)); } else { From 7b861e75d7590f98a3450b85fe90a35343947b75 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Fri, 10 Jan 2025 01:47:14 -0800 Subject: [PATCH 051/239] Take token count quantization of fused attention into consideration for CP results correction (#1396) * fix second half lse shape Signed-off-by: Xiaowei Ren * bug fixes Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/fused_attn/thd_utils.h | 17 ++++--- transformer_engine/pytorch/attention.py | 13 ++++- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cu | 51 +++++++++++-------- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/common/fused_attn/thd_utils.h index 91f5f7bac1..ec265e4366 100644 --- a/transformer_engine/common/fused_attn/thd_utils.h +++ b/transformer_engine/common/fused_attn/thd_utils.h @@ -69,7 +69,7 @@ struct ReadLseFunctor { template __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int total_tokens) { + int num_heads, int lse_seqlen, int second_half_lse_seqlen) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / 2; @@ -85,15 +85,15 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { size_t idx, half_idx; if constexpr (lse_packed) { - idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; - half_idx = head_id * total_tokens / 2 + token_id; + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * second_half_lse_seqlen + token_id; } else { size_t row = static_cast(seq_id) * num_heads + head_id; int col = token_id - cu_seqlens_s[seq_id]; int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * total_tokens + col + seq_len; - half_idx = row * total_tokens / 2 + col; + idx = row * lse_seqlen + col + seq_len; + half_idx = row * second_half_lse_seqlen + col; } Functor::run(lse, half_lse, idx, half_idx); @@ -108,7 +108,8 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, template __global__ void thd_out_correction_kernel(dtype *out, dtype 
*out_per_step, float *lse, float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int lse_seqlen) { + int num_heads, int dim_per_head, int lse_seqlen, + int lse_per_step_seqlen) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); @@ -128,13 +129,13 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float if constexpr (lse_packed) { idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; + idx_per_step = head_id * lse_per_step_seqlen + token_id; } else { size_t row = static_cast(seq_id) * num_heads + head_id; int col = token_id - cu_seqlens_s[seq_id]; int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; idx = row * lse_seqlen + col + seq_len * only_second_half; - idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; + idx_per_step = row * lse_per_step_seqlen + col; } float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3f0267affb..55c8a2fcf2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2473,6 +2473,10 @@ def forward( torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + second_half_lse_seqlen = None + if causal and rank < (cp_size - 1): + second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + softmax_lse = softmax_lse.to(torch.float) for i in range(cp_size): if i <= rank or not causal: @@ -2621,6 +2625,7 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format + ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta 
ctx.is_input_fp8 = is_input_fp8 @@ -2670,10 +2675,14 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None - if causal: + softmax_lse_ = None + if causal and ctx.second_half_lse_seqlen is not None: if ctx.qkv_format == "thd": softmax_lse_ = tex.thd_read_second_half_lse( - softmax_lse, cu_seqlens_q_padded, ctx.softmax_lse_in_packed_format + softmax_lse, + cu_seqlens_q_padded, + ctx.softmax_lse_in_packed_format, + ctx.second_half_lse_seqlen, ) else: # [b, np, sq] -> [b, np, 2, sq//2] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3abcac5bf7..67fd1caf5b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -444,7 +444,7 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st const at::Tensor &cu_seqlens, bool lse_packed); at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - bool lse_packed); + bool lse_packed, int second_half_lse_seqlen); void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index f947930d23..9c9ffdb1a7 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1420,7 +1420,7 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch, num_heads, total_tokens; + int batch, num_heads, lse_seqlen, second_half_lse_seqlen; if (lse_packed) { NVTE_CHECK(lse.dim() == 2); @@ -1428,48 +1428,50 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st batch = cu_seqlens.size(0) - 1; num_heads = lse.size(0); - 
total_tokens = lse.size(1); + lse_seqlen = lse.size(1); + second_half_lse_seqlen = lse_per_step.size(1); NVTE_CHECK(lse_per_step.size(0) == num_heads); - NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2); + NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); } else { NVTE_CHECK(lse.dim() == 3); NVTE_CHECK(lse_per_step.dim() == 3); batch = lse.size(0); num_heads = lse.size(1); - total_tokens = lse.size(2); + lse_seqlen = lse.size(2); + second_half_lse_seqlen = lse_per_step.size(2); NVTE_CHECK(lse_per_step.size(0) == batch); NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2); + NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); } constexpr unsigned int block = 256; - unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; if (lse_packed) { thd_lse_kernel <<>>( lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), - batch, num_heads, total_tokens); + batch, num_heads, lse_seqlen, second_half_lse_seqlen); } else { thd_lse_kernel <<>>( lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), - batch, num_heads, total_tokens); + batch, num_heads, lse_seqlen, second_half_lse_seqlen); } } at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - bool lse_packed) { + bool lse_packed, int second_half_lse_seqlen) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch, num_heads, total_tokens; + int batch, num_heads, lse_seqlen; std::vector shape; if (lse_packed) { @@ -1477,37 +1479,40 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ batch = cu_seqlens.size(0) - 1; num_heads = lse.size(0); - total_tokens = lse.size(1); + lse_seqlen = lse.size(1); - shape 
= {num_heads, total_tokens / 2}; + NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); + + shape = {num_heads, second_half_lse_seqlen}; } else { NVTE_CHECK(lse.dim() == 3); batch = lse.size(0); num_heads = lse.size(1); - total_tokens = lse.size(2); + lse_seqlen = lse.size(2); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); - shape = {batch, num_heads, total_tokens / 2}; + shape = {batch, num_heads, second_half_lse_seqlen}; } at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); constexpr unsigned int block = 256; - unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; if (lse_packed) { thd_lse_kernel <<>>( lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, total_tokens); + num_heads, lse_seqlen, second_half_lse_seqlen); } else { thd_lse_kernel <<>>( lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, total_tokens); + num_heads, lse_seqlen, second_half_lse_seqlen); } return half_lse; @@ -1534,23 +1539,25 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ NVTE_CHECK(out_per_step.size(1) == num_heads); NVTE_CHECK(out_per_step.size(2) == dim_per_head); - int batch, lse_seqlen; + int batch, lse_seqlen, lse_per_step_seqlen; if (lse_packed) { batch = cu_seqlens.size(0) - 1; lse_seqlen = lse.size(1); + lse_per_step_seqlen = lse_per_step.size(1); NVTE_CHECK(lse.size(0) == num_heads); NVTE_CHECK(lse_seqlen >= total_tokens); NVTE_CHECK(lse_per_step.size(0) == num_heads); - NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1)); } else { batch = lse.size(0); lse_seqlen = lse.size(2); + lse_per_step_seqlen = lse_per_step.size(2); NVTE_CHECK(lse.size(1) == num_heads); NVTE_CHECK(lse_per_step.size(0) == 
batch); NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1)); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); } @@ -1565,13 +1572,13 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ <<>>( out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, - dim_per_head, lse_seqlen); + dim_per_head, lse_seqlen, lse_per_step_seqlen); } else { thd_out_correction_kernel <<>>( out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, - dim_per_head, lse_seqlen); + dim_per_head, lse_seqlen, lse_per_step_seqlen); } } From a65ad37e622ad89837b15520b9f2b6c7232d3423 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 10 Jan 2025 16:53:31 -0800 Subject: [PATCH 052/239] [JAX] Test_multiprocessing_encoder with process spawn in bash (#1394) * add test_multiprocessing_encoder with processing spawning in bash --------- Signed-off-by: Phuong Nguyen --- examples/jax/encoder/common.py | 7 ++ examples/jax/encoder/conftest.py | 20 ++++++ .../run_test_multiprocessing_encoder.sh | 17 +++++ .../encoder/test_multiprocessing_encoder.py | 70 ++++++------------- qa/L0_jax_distributed_unittest/test.sh | 2 +- 5 files changed, 67 insertions(+), 49 deletions(-) create mode 100644 examples/jax/encoder/conftest.py create mode 100644 examples/jax/encoder/run_test_multiprocessing_encoder.sh diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index c79fa45239..93dbd408ea 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -12,3 +12,10 @@ def is_bf16_supported(): """Return if BF16 has hardware supported""" gpu_arch = get_device_compute_capability(0) return gpu_arch >= 80 + + +@lru_cache +def is_fp8_supported(): + """Return if FP8 has hardware 
supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 90 diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py new file mode 100644 index 0000000000..b1648892aa --- /dev/null +++ b/examples/jax/encoder/conftest.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""config for test_multiprocessing_encoder""" +import pytest + + +def pytest_addoption(parser): + """Pytest hook for test_multiprocessing_encoder""" + parser.addoption("--num-process", action="store", default=0) + parser.addoption("--process-id", action="store", default=0) + + +@pytest.fixture(autouse=True) +def multiprocessing_parses(request): + """Fixture for querying num-process and process-id""" + if request.cls: + request.cls.num_process = int(request.config.getoption("--num-process")) + request.cls.process_id = int(request.config.getoption("--process-id")) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh new file mode 100644 index 0000000000..6a1dd96739 --- /dev/null +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i & +done +wait + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i & +done +wait diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index ff6fd4d167..7d2df77b7d 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -3,10 +3,10 @@ # See LICENSE for license information. """Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" import argparse -import multiprocessing as mp import os import unittest from functools import partial +import pytest import flax import jax @@ -21,10 +21,10 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding +from common import is_bf16_supported, is_fp8_supported import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from common import is_bf16_supported os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" @@ -252,7 +252,6 @@ def eval_model( def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) @@ -342,6 +341,9 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + if args.process_id == 0: + nltk.download("punkt_tab") + train_ds, test_ds, num_embed = 
get_datasets(args.max_seq_len) jax.distributed.initialize( @@ -551,69 +553,41 @@ def encoder_parser(args): return parser.parse_args(args) -def query_gpu(q): - """Query GPU info on the system""" - gpu_has_fp8, reason = te.fp8.is_fp8_available() - gpu_has_bf16 = is_bf16_supported() - num_gpu = len(jax.devices()) - q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason]) - - -def unittest_query_gpu(): - r""" - It is only used by TestEncoder. - The `jax.distributed.initialize` must be called before any other JAX or Flax API, - otherwise `jax.local_devices` will be incorrect. - Thus, fork another process to query number of GPUs and FP8 capability. - """ - q = mp.Queue() - p = mp.Process(target=query_gpu, args=(q,)) - p.start() - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get() - p.join() - return num_gpu, gpu_has_fp8, gpu_has_bf16, reason - - +@pytest.mark.usefixtures("multiprocessing_parses") class TestEncoder(unittest.TestCase): """Encoder unittests""" - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() + gpu_has_fp8 = is_fp8_supported() + gpu_has_bf16 = is_bf16_supported() def exec(self, use_fp8): """Run 3 epochs for testing""" - num_gpu = self.num_gpu + args = encoder_parser([]) + + num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 dp_size = num_gpu // tp_size batch_size = 64 // dp_size - arg_list = [] - for i in range(num_gpu): - args = encoder_parser([]) - args.num_process = num_gpu - args.use_fp8 = use_fp8 - args.batch_size = batch_size - args.test_batch_size = batch_size - args.process_id = i - arg_list.append(args) - - with mp.Pool(self.num_gpu) as p: - results = p.map(train_and_evaluate, arg_list) + args.use_fp8 = use_fp8 + args.batch_size = batch_size + args.test_batch_size = batch_size + args.num_process = num_gpu + args.process_id = self.process_id - return results + return train_and_evaluate(args) @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): 
"""Test Transformer Engine with BF16""" - results = self.exec(False) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(False) + assert result[0] < 0.45 and result[1] > 0.79 - @unittest.skipIf(not gpu_has_fp8, reason) + @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8") def test_te_fp8(self): """Test Transformer Engine with FP8""" - results = self.exec(True) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(True) + assert result[0] < 0.45 and result[1] > 0.79 if __name__ == "__main__": diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index f1d1c06d38..947796b029 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +. 
$TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh From cbc46531705c4c641c0b4593b8303692bf81b3a4 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 13 Jan 2025 11:27:58 -0800 Subject: [PATCH 053/239] Fix "refractor" typo in the PR template (#1402) Signed-off-by: Sergii Dymchenko --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5fee1b7191..abd8f33ccd 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -11,7 +11,7 @@ Fixes # (issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Infra/Build change -- [ ] Code refractor +- [ ] Code refactoring ## Changes From 240240617267cff76178a7f5da58a93806e5a6d2 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 13 Jan 2025 14:24:08 -0600 Subject: [PATCH 054/239] [PyTorch] Adding TP overlap support for `te.Linear` with `parallel_mode="column"` (#1343) * support AG overlap in sequence-parallel Linear forward and RS overlap in sequence-parallel Linear backward Signed-off-by: Alp Dener * implemented TP overlap support for column-parallel te.Linear Signed-off-by: Alp Dener * fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * improved error messages for internal failure to infer TP overlap options in te.Linear Signed-off-by: Alp Dener * fixed linting errors Signed-off-by: Alp Dener * fixed incorrect TP overlap option asserts Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../distributed/run_layer_with_overlap.py | 62 +++- .../distributed/test_comm_gemm_overlap.py | 24 +- transformer_engine/pytorch/module/linear.py | 347 +++++++++++++----- 3 files changed, 322 insertions(+), 111 deletions(-) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e49174c24f..5a67bd616a 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): kwargs["ub_overlap_ag"] = not reference if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = "row" - kwargs["ub_overlap_rs"] = not reference - kwargs["ub_name"] = "proj" + if config.linear_parallel_mode == "row": + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["ub_overlap_rs"] = not reference + elif config.linear_parallel_mode == "column": + input_shape[0] = config.seq_length // tp_size + args.append(3 * hidden_size) + kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["parallel_mode"] = config.linear_parallel_mode + kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv" else: input_shape[0] = config.seq_length // tp_size - kwargs["ub_bulk_wgrad"] = not reference - kwargs["ub_bulk_dgrad"] = not reference + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference if config.layer_type is te.LayerNormLinear: args.append(3 * hidden_size) kwargs["parallel_mode"] = "column" @@ 
-125,6 +133,19 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." ) + parser.add_argument( + "--linear-parallel-mode", + type=str.lower, + default="row", + choices=["row", "column"], + help="Parallel mode for te.Linear.", + ) + parser.add_argument( + "--overlap-rs-dgrad", + action="store_true", + default=False, + help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.", + ) parser.add_argument( "--debug", action="store_true", @@ -230,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Intialize userbuffers + ub_cfgs = None + if opts.overlap_rs_dgrad: + ub_cfgs = { + "proj_dgrad": {"method": "ring_exchange"}, + "qkv_dgrad": {"method": "ring_exchange"}, + } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], WORLD_SIZE, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=ub_cfgs, ) # Initialize the Transformer Engine layer with overlap @@ -314,27 +342,29 @@ def run_fwd_bwd(model, x): ref_grads.append(ref_param.grad) # Make sure we have the same number of gradients - numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") if len(test_grads) != len(ref_grads): - numerics_failed[0] = 1 + num_grads_failed[0] = 1 numerics_info = ( "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + f"expected {len(ref_grads)} but got {len(test_grads)}." 
) dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world) # Now validate accuracy - if not bool(numerics_failed.item()): + numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda") + if not bool(num_grads_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): rtol = 0.125 if opts.fp8 else 0.025 atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[0] = int(grad_failed) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()): - break + numerics_failed[i] = int(grad_failed) + return_code = torch.max(numerics_failed) + dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world) + else: + return_code = num_grads_failed te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) @@ -344,7 +374,7 @@ def run_fwd_bwd(model, x): if opts.debug and WORLD_RANK == 0: print("Exiting...\n", end="", flush=True) - return numerics_failed[0].item() + return return_code.item() if __name__ == "__main__": diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 240e396534..c285da7fbd 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -21,8 +21,10 @@ BATCH_SIZE: int = 2 NUM_HEADS: int = 12 HEAD_DIM: int = 64 + +# NOTE: te.Linear is intentionally omitted here and manually added later for testing both +# row and column parallel layouts. 
TE_LAYERS = [ - te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, @@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] + if layer_type == te.Linear.__name__: + test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") if fp8: if not fp8_available: @@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections): @pytest.mark.parametrize( - "layer_type", - [layer.__name__ for layer in TE_LAYERS], - ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], + "layer_type,linear_parallel_mode", + ( + [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")] + + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) + ), + ids=( + [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "] + + [(" " + layer.__name__ + " ") for layer in TE_LAYERS] + ), ) @pytest.mark.parametrize( "fp8,fp8_init", @@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections): " FP8 GEMM - FP8 PARAMS ", ], ) -def test_layers_with_overlap(layer_type, fp8, fp8_init): +def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): """ Test Transformer Engine layers with comm+GEMM overlap. 
""" - _run_layer_with_overlap(layer_type, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5fd4dd2fc9..2262d23832 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,8 @@ # See LICENSE for license information. """Linear API""" +from functools import reduce +from operator import mul as multiply_op from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -43,7 +45,7 @@ fp8_cast_transpose_fused, cast_to_fp8, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor @@ -80,8 +82,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, ub_name: str, fp8_output: bool, fsdp_group: Union[dist_group_type, None], @@ -99,7 +105,8 @@ def forward( assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs + ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop + ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop # Cast input to expected dtype inputmat = cast_if_needed(inputmat, activation_dtype) @@ -150,10 +157,11 @@ def forward( inputmat_scale_inv.fill_(inputmat_scale_inv.item()) # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel: + if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop: inputmat_total, _ = 
gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat + if fp8: bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -165,75 +173,92 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], fp8_dtype_forward, torch.uint8, ) else: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( None, None, None, activation_dtype, ) + ub_obj = None ub_algo = None rs_out = None - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) + inputmat_data = ( + inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total + ) + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + out = ub_obj.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): - if ub_obj_projout.is_atomic_gemm(): + if ub_obj.is_p2p_overlap(): + if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - if ub_obj_projout.is_atomic_gemm(): + if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj_projout.is_fp8_ubuf(): - proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT + if ub_obj.is_fp8_ubuf(): + out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] - proj_out_tetype = fp8_dtype_forward - proj_out_pttype = torch.uint8 - 
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) + out_tedtype = fp8_dtype_forward + out_pttype = torch.uint8 + ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer." + ub_obj.copy_input_to_ubuf(inputmat_data, True) + ub_obj.set_ubuf_scale_inv(inputmat_scale_inv) + if ub_obj.is_atomic_gemm(): + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + out_tedtype = TE_DType[activation_dtype] + out_pttype = activation_dtype + dim_size = list(inputmat_total.size()) + dim_size[0] *= tp_size + dim_size[1] = out_features + out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) + else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features - out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) + out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) _ = fp8_gemm( weight_fp8._data, weight_fp8._scale_inv, 0, weight_fp8._fp8_dtype, - ( - inputmat_total._data - if isinstance(inputmat_total, Float8Tensor) - else inputmat_total - ), + inputmat_data, inputmat_scale_inv, 0, fp8_dtype_forward, - proj_out_pttype, + out_pttype, get_workspace(), bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=proj_out_index, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out_index=out_index, fp8_meta_tensor=meta_tensor, - D_dtype=proj_out_tetype, + D_dtype=out_tedtype, ) if fp8_output: out = Float8Tensor( @@ -261,17 +286,30 @@ def forward( -amin, amax ).float() - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) + ub_obj = None + 
ub_algo = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + out = ub_obj.get_ubuf_output(1) dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) + dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): + if ub_obj.is_p2p_overlap(): ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_obj.copy_input_to_ubuf(inputmat_total, True) + dim_size = list(inputmat_total.size()) + dim_size[0] *= tp_size # all-gathered sequence length + dim_size[1] = out_features + out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features @@ -285,9 +323,9 @@ def forward( bias=bias, use_bias=use_bias, out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, ) if is_grad_enabled: @@ -343,7 +381,10 @@ def forward( ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad @@ -356,12 +397,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if ub_overlap_rs: - out = rs_out - elif parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - 
elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if parallel_mode == "row": + if ub_overlap_rs_fprop: + out = rs_out + elif sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp_shape[1:-1], out_features) @@ -401,15 +443,75 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], tp_world_size = get_distributed_world_size(ctx.tp_group) ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ub_algo = None + ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad + + ctx.ub_obj_gradout = None + ub_obj_wgrad = None + ub_algo_wgrad = None + ub_algo_dgrad = None + rs_out = None + dgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] * tp_world_size + # Overlap grad_output all-gather with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + dgrad = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + dgrad = ctx.ub_obj_gradout.get_ubuf_output(1) + if ctx.ub_obj_gradout.is_p2p_overlap(): + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = 
tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + inputmat_data = ( + inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat + ) + ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True) + inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1) + if isinstance(inputmat, Float8Tensor): + inputmat._data = inputmat_ubuf + else: + inputmat = inputmat_ubuf + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + dgrad = ub_obj_wgrad.get_ubuf_output(1) + + if dgrad is None: + if ctx.parallel_mode == "column" and ctx.sequence_parallel: + dgrad_shape[0] = dgrad_shape[0] * tp_world_size + dgrad = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) ( grad_output, @@ -420,13 +522,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx, grad_output, ctx.parallel_mode == "row" ) - # Column Parallel Linear - # Overlap input AG with dgrad + # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers) inputmat_total = None inputmat_t_total = None - handle = None - if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel: - inputmat_total, handle = gather_along_first_dim( + inputmat_gather_handle = None + if ( + weight.requires_grad + and ctx.parallel_mode == 
"column" + and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad + ): + inputmat_total, inputmat_gather_handle = gather_along_first_dim( inputmat, ctx.tp_group, async_op=ctx.requires_dgrad ) else: @@ -446,13 +552,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - if ctx.is_input_fp8: + if ctx.is_input_fp8 or ( + ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf() + ): out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, ctx.fp8_meta["scaling_bwd"], fp8_dtype_backward, torch.uint8, ) + if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf(): + ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: out_index, meta_tensor, output_te_dtype, output_dtype = ( None, @@ -460,7 +570,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, ctx.activation_dtype, ) - dgrad, _ = fp8_gemm( + _ = fp8_gemm( weight_fp8.transpose_2d(), weight_fp8._scale_inv, 0, @@ -472,12 +582,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], output_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ub_algo=ub_algo_dgrad, + ub=ctx.ub_obj_gradout, + out=dgrad, out_index=out_index, fp8_meta_tensor=meta_tensor, D_dtype=output_te_dtype, + extra_output_tensor=rs_out, ) + + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + if output_dtype == torch.uint8: dgrad = Float8Tensor( data=dgrad, @@ -488,30 +604,34 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, ) else: - dgrad, _, _ = gemm( + _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), layout="NN", grad=True, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - if ctx.ub_overlap_ag - else None - ), - ub=ctx.ub_obj_gradout if 
ctx.ub_overlap_ag else None, + ub_algo=ub_algo_dgrad, + ub=ctx.ub_obj_gradout, + out=dgrad, + extra_output_tensor=rs_out, ) - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter_along_first_dim( + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + + if inputmat_gather_handle is not None: + inputmat_gather_handle.wait() + + # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers) + dgrad_reduce_handle = None + if ctx.requires_dgrad and ctx.parallel_mode == "column": + if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad): + dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim( dgrad, ctx.tp_group, async_op=True ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel and not ctx.sequence_parallel: + dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True) wgrad = None if weight.requires_grad: @@ -548,6 +668,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: wgrad, _, _ = gemm( @@ -559,6 +681,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: # WGRAD @@ -572,15 +696,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) + if 
ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_ubuf_output(0) + # Deallocate input tensor clear_tensor_data(inputmat_total) clear_tensor_data(inputmat_t_total) - # Column Parallel Linear - if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: - handle.wait() + # Wait for dgrad reduce-scatter or all-reduce + if dgrad_reduce_handle is not None: + dgrad_reduce_handle.wait() if not ctx.use_bias: grad_bias = None @@ -634,8 +763,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # activation_dtype None, # parallel_mode None, # is_grad_enabled - None, # ub_overlap_rs - None, # ub_overlap_ag + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name None, # fp8_output None, # fsdp_group @@ -729,8 +862,10 @@ def __init__( parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, device: Union[torch.device, str] = "cuda", - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -742,13 +877,6 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag - if ub_overlap_rs or ub_overlap_ag: - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name - self.get_rng_state_tracker = get_rng_state_tracker - self.rng_tracker_name = rng_tracker_name if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." 
@@ -773,6 +901,45 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column parallel TP overlap options + self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag + self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs + self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad + self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad + if self.ub_overlap_rs_dgrad: + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + + # Row parallel TP overlap options + self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs + self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag + + if any( + [ + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): + assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." + self.ub_name = ub_name + + assert not ( + self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop + ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." + assert not ( + self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad + ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." + assert not ( + self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) + ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time." 
+ + self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1017,8 +1184,12 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), - self.ub_overlap_rs, - self.ub_overlap_ag, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.ub_name, fp8_output, self.fsdp_group, From 3d63cbb469ace2d1ad7798a956e16dee42bd655b Mon Sep 17 00:00:00 2001 From: guyueh1 <140554423+guyueh1@users.noreply.github.com> Date: Wed, 15 Jan 2025 10:12:25 -0800 Subject: [PATCH 055/239] Make it an option to compile activation functions with fast math (#1410) * Add a compile option to compile activation kernels with fast math Signed-off-by: Guyue Huang * Fix Signed-off-by: Guyue Huang * Apply suggestions from code review Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com> --------- Signed-off-by: Guyue Huang Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- setup.py | 3 +++ transformer_engine/common/CMakeLists.txt | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/setup.py b/setup.py index 16e988aa88..643dd7a908 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,9 @@ def setup_common_extension() -> CMakeExtension: ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): + cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + # Project directory root root_path = Path(__file__).resolve().parent diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3efe116105..3afddcc48d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ 
b/transformer_engine/common/CMakeLists.txt @@ -147,6 +147,14 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") +option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) +if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) + set_source_files_properties(activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + PROPERTIES + COMPILE_OPTIONS "--use_fast_math") +endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") From c2937c5abacb85326f093e74bb282fb491b30b3d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 16 Jan 2025 14:32:50 -0600 Subject: [PATCH 056/239] [PyTorch] `te.Linear` FP8 DGRAD+RS output bugfix (#1412) * corrected RS overlap BF16 output clashing with Float8Tensor constructor Signed-off-by: Alp Dener * fixed empty dgrad buffer dtype at initialization Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/module/linear.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2262d23832..5893c4ea3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -506,13 +506,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") dgrad = ub_obj_wgrad.get_ubuf_output(1) - if dgrad is None: - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - dgrad_shape[0] = dgrad_shape[0] * tp_world_size - dgrad = torch.empty( - dgrad_shape, 
dtype=ctx.activation_dtype, device=grad_output.device - ) - ( grad_output, grad_output_c, @@ -550,6 +543,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + output_dtype = ctx.activation_dtype if ctx.requires_dgrad: if ctx.fp8: if ctx.is_input_fp8 or ( @@ -570,6 +564,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, ctx.activation_dtype, ) + + if dgrad is None: + if ctx.parallel_mode == "column" and ctx.sequence_parallel: + dgrad_shape[0] = dgrad_shape[0] * tp_world_size + dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device) + + if ctx.requires_dgrad: + if ctx.fp8: _ = fp8_gemm( weight_fp8.transpose_2d(), weight_fp8._scale_inv, @@ -593,8 +595,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.ub_overlap_rs_dgrad: dgrad = rs_out - - if output_dtype == torch.uint8: + elif output_dtype == torch.uint8: dgrad = Float8Tensor( data=dgrad, fp8_meta=ctx.fp8_meta, From 6e848924fe5c81a233f501cad2f02638b2aa41e4 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Thu, 16 Jan 2025 20:08:48 -0800 Subject: [PATCH 057/239] [JAX] Consolidate the distributed fused attention test code (#1405) Consolidate the distributed fused attention tests to shared input generation and execition logic. 
Signed-off-by: Michael Goldfarb --- tests/jax/distributed_test_base.py | 29 +- tests/jax/test_distributed_fused_attn.py | 444 +++++------------------ tests/jax/test_fused_attn.py | 242 ++++++++++-- 3 files changed, 323 insertions(+), 392 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index c2d7039a53..d0ace8263f 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -18,14 +18,22 @@ def generate_configs(): configs = [] if is_devices_enough(2): - configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")]) - configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")]) + configs.append( + pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") + ) + configs.append( + pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2") + ) if is_devices_enough(4): - TP_size = 2 - DP_size = 2 configs.append( - [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")] + pytest.param( + 4, + (2, 2), + ("dp", "tp"), + MeshResource(dp_resource="dp", tp_resource="tp"), + id=f"n4_dp2_tp2", + ) ) return configs @@ -33,7 +41,8 @@ def generate_configs(): def generate_context_parallel_configs(): configs = [] - + mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") + axes = ("dp", "cp", "tp") DP_sizes = (1, 2) CP_sizes = (1, 2, 4, 8) TP_sizes = (1, 2) @@ -41,13 +50,7 @@ def generate_context_parallel_configs(): ndev = cp * tp * dp if is_devices_enough(ndev): configs.append( - pytest.param( - ndev, - (dp, cp, tp), - ("dp", "cp", "tp"), - MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"), - id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}", - ) + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") ) return configs diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 5a41911691..2e15dd4d5d 100644 --- 
a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -37,8 +37,7 @@ ) from transformer_engine.jax.sharding import MeshResource -# We will use the golden reference model from our non distributed attention test fixture. -from test_fused_attn import general_dot_product_attention, make_mask +from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask DTYPES = [jnp.float16, jnp.bfloat16] @@ -49,7 +48,7 @@ def generate_collectives_count_ref( self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) - _, seqlen, _, heads, _ = shape + _, seqlen, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None tp_size = 1 if mesh_resource.tp_resource is not None: @@ -62,45 +61,28 @@ def generate_collectives_count_ref( # for loss and dbias return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype): - batch, seqlen, _, heads, _ = shape - - qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - - bias = ( - random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) - if with_bias - else None - ) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - qkv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - bias_pspec = ( - PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - 
@pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize( - "attn_bias_type", - [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS], + "data_shape", + [ + pytest.param((32, 512, 12, 64), id="32-512-12-64"), + pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), + ], ) @pytest.mark.parametrize( - "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + ], ) @pytest.mark.parametrize("dtype", DTYPES) def test_self_attn( @@ -111,14 +93,14 @@ def test_self_attn( mesh_resource, data_shape, attn_bias_type, + bias_shape, attn_mask_type, dtype, ): dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, _, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -136,74 +118,36 @@ def test_self_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(qkv, bias, mask): - return jnp.mean( - fused_attn( - (qkv,), - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BS3HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ) - ) - - def ref_func(qkv, bias, mask): - query, key, value = jnp.split(qkv, [1, 2], axis=-3) - query = jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - deterministic=is_training, - 
dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output).astype(dtype) - - with_bias = attn_bias_type != AttnBiasType.NO_BIAS - (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, with_bias, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref( + mesh_shape, + mesh_axes, + mesh_resource, + attn_bias_type != AttnBiasType.NO_BIAS, + data_shape, + dtype, ) - collective_count_ref = self.generate_collectives_count_ref( - mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BS3HD, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) - bias_ = ( - jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias - ) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - grad_args = (0, 1) if with_bias else (0,) - out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) - - compare_ops( - target_func, - ref_func, - [qkv_, bias_, mask_], - collective_count_ref, - grad_args=grad_args, - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(qkv_pspec, bias_pspec, mask_pspec), - out_shardings=(None, out_grad_shardings), - ) + runner.test_backward() class TestDistributedCrossAttn: @@ -213,31 +157,6 @@ def generate_collectives_count_ref(self): all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, 
allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype): - batch, seqlen, heads, hidden = shape - - q = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) - - kv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest.mark.parametrize( @@ -248,11 +167,11 @@ def test_cross_attn( self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -270,67 +189,29 @@ def test_cross_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(q, kv, mask): - return jnp.mean( - fused_attn( - (q, kv), - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BSHD_BS2HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ), - dtype=jnp.float32, - ) - - def ref_func(query, kv, mask): - key, value = jnp.split(kv, [1], axis=-3) - query = 
jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=None, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output, dtype=jnp.float32) - - (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref() + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BSHD_BS2HD, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - collective_count_ref = self.generate_collectives_count_ref() - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) - kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - compare_ops( - target_func, - ref_func, - [q_, kv_, mask_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(q_pspec, kv_pspec, mask_pspec), - out_shardings=(None, (q_pspec, kv_pspec)), - ) + runner.test_backward() @pytest.mark.parametrize( @@ -366,41 +247,6 @@ def ref_func(query, kv, mask): ) class TestDistributedContextParallelSelfAttn: - def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): - batch, seqlen, heads, hidden = shape - kv_shape = (batch, seqlen, heads // kv_groups, hidden) - qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) - q = random.normal(qkey, shape, dtype=dtype) - k = 
random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - - def gen_valid(bs, max_seqlen, pad_ratio): - pad_len = int(max_seqlen * pad_ratio) - valid_len = max_seqlen - pad_len - tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1) - return tokens, jnp.logical_not(tokens) - - from test_fused_attn import make_mask - - q_idx, _ = gen_valid(batch, seqlen, 0.0) - kv_idx, _ = gen_valid(batch, seqlen, 0.0) - mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type) - - return q, k, v, mask - - def qkv_to_layout(self, q, k, v, qkv_layout): - qkv_args = () - match qkv_layout: - case QKVLayout.BSHD_BS2HD: - k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) - kv = jnp.concatenate((k, v), axis=-3) - qkv_args = (q, kv) - case QKVLayout.BSHD_BSHD_BSHD: - qkv_args = (q, k, v) - case _: - raise ValueError(f"Unsupported {qkv_layout=}") - return qkv_args - def impl_test_context_parallel_attn( self, device_count, @@ -416,6 +262,7 @@ def impl_test_context_parallel_attn( cp_strategy, ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape @@ -431,6 +278,29 @@ def impl_test_context_parallel_attn( num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head) + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_kv_heads, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + cp_strategy=cp_strategy, + cp_load_balanced=load_balanced, + ) + def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( dtype, @@ -465,123 +335,7 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % 
tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - def target_func(q, k, v, mask): - return fused_attn( - self.qkv_to_layout(q, k, v, qkv_layout), - None, # bias - mask, - None, # seed - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - context_parallel_strategy=cp_strategy, - context_parallel_causal_load_balanced=load_balanced, - context_parallel_axis="cp", - ).astype(dtype) - - def ref_func(q, k, v, mask): - output = general_dot_product_attention( - q, - k, - v, - bias=None, - mask=mask, - deterministic=not is_training, - scale_factor=scaling_factor, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - return output.astype(dtype) - - def grad_func(func, *args, **kwargs): - # Gradient is small, use a gradient multiplier to amplify the gradient - _, max_seq_len, num_heads, _ = data_shape - gradient_multiplier = max_seq_len * num_heads - if attn_mask_type.is_causal(): - gradient_multiplier /= 10 - ret_valid = func(*args, **kwargs) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) - - q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) - - diff_argnums = (0, 1, 2) - - # Single GPU (reference) - ref_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums - ) - ) - ref_fwd, ref_grads = ref_func_jit(q, k, v, mask) - - # Multi GPU (function under test) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False): - qkv_ps = PartitionSpec( - mesh_resource.dp_resource, - mesh_resource.cp_resource, - mesh_resource.tp_resource, - None, - ) - qkv_sharding = NamedSharding(mesh, qkv_ps) - - mask_ps = PartitionSpec( - 
mesh_resource.dp_resource, None, mesh_resource.cp_resource, None - ) - mask_sharding = NamedSharding(mesh, mask_ps) - - reorder = partial( - reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - inverse_reorder = partial( - inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - - if load_balanced: - q, k, v = jax.tree.map(reorder, (q, k, v)) - - q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) - mask_ = jax.device_put(mask, device=mask_sharding) - - target_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(target_func, q, k, v, mask), - argnums=diff_argnums, - ), - in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], - out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), - ) - - target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) - - if load_balanced: - target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) - target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) - - has_diffs = False - - print_debug_tensor_stats("target", target_fwd) - print_debug_tensor_stats("ref", ref_fwd) - print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd)) - assert_allclose(target_fwd, ref_fwd, dtype=dtype) - - for i in range(len(target_grads)): - if ref_grads[i] is None or target_grads[i] is None: - # expect both none if one is - assert target_grads[i] is None and ref_grads[i] is None - else: - print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i]) - print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i]) - print_debug_tensor_stats( - f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i]) - ) - - assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) + runner.test_backward() def test_context_parallel_allgather_attn( self, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 5cbbec7b04..710ae1946d 100644 --- a/tests/jax/test_fused_attn.py +++ 
b/tests/jax/test_fused_attn.py @@ -3,10 +3,10 @@ # See LICENSE for license information. """Tests for fused attention""" from enum import Enum -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from math import sqrt -from typing import Tuple, Optional +from typing import Tuple, Optional, Dict import random import jax @@ -19,16 +19,22 @@ from flax.linen.dtypes import promote_dtype from jax import Array from jax import value_and_grad, jit +from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, QKVLayout, QKVFormat, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, fused_attn, fused_attn_thd, make_swa_mask, + CPStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.transformer_engine_jax import ( @@ -36,7 +42,8 @@ get_cudnn_version, ) -from utils import assert_allclose +from distributed_test_base import assert_equal_collectives +from utils import assert_allclose, print_debug_tensor_stats @pytest.fixture(autouse=True, scope="module") @@ -304,6 +311,19 @@ class FusedAttnRunner: bias_shape: BiasShape window_size: Optional[Tuple[int, int]] = None + # Specifies sharding resources for distributed tests + number_of_devices: int = 1 + mesh_shape: tuple[int, ...] = (1, 1, 1) + mesh_axes: tuple[str, ...] 
= ("dp", "cp", "tp") + mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp")) + + # Context parallel aux arguments + cp_strategy: CPStrategy = CPStrategy.DEFAULT + cp_load_balanced: bool = True + + # dictionary of expected collective comm bytes + coll_count_ref: Optional[Dict[str, int]] = None + # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): @@ -362,6 +382,14 @@ def _check_configs(self): def _setup_inputs(self): self._check_configs() + + # Create a mesh for distributed tests + self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape) + self.mesh = Mesh(self.devices, self.mesh_axes) + self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) + self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) + self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1) + key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) @@ -527,6 +555,66 @@ def generate_random_segment_ids( self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) + # Setup distributed sharding specs + # Setup shardings for distributed tests + self.qkvo_psec = PartitionSpec( + self.mesh_resource.dp_resource, + self.mesh_resource.cp_resource, + self.mesh_resource.tp_resource, + None, + ) + self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec) + + self.mask_pspec = PartitionSpec( + self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None + ) + self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec) + + if self.bias_shape == BiasShape._1HSS: + self.bias_pspec = PartitionSpec( + None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None + ) + elif 
self.bias_shape == BiasShape._B1SS: + self.bias_pspec = PartitionSpec( + self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None + ) + elif self.bias_shape == BiasShape._11SS: + self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) + else: + self.bias_pspec = PartitionSpec() + self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec) + + self.dropout_rng_pspec = PartitionSpec( + None, + ) + self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec) + + self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) + self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec) + + # [batch][max_segments_per_batch] + # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset. + self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) + self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) + + # Softmax aux sharding + + if self.cp_size > 1 and self.cp_load_balanced: + self.cp_reorder_fn = partial( + reorder_causal_load_balancing, + cp_size=self.cp_size, + tensor_format=self.qkv_layout.get_qkv_format(), + ) + self.cp_inverse_reorder_fn = partial( + inverse_reorder_causal_load_balancing, + cp_size=self.cp_size, + tensor_format=self.qkv_layout.get_qkv_format(), + ) + else: + # no-ops for non cp or non load balanced + self.cp_reorder_fn = lambda x: x + self.cp_inverse_reorder_fn = lambda x: x + def test_forward(self): """ Test forward without JIT @@ -534,17 +622,21 @@ def test_forward(self): self._setup_inputs() args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] + customcall_args = [ - self.q, - self.k, - self.v, - self.bias, - self.mask_for_customcall, - self.seqlens_q, - self.seqlens_kv, - self.offsets_q, - self.offsets_kv, - self.dropout_rng, + # Put test data onto each GPU for distributed. 
+ # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and + # THD params once we support those features on CP. + jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.mask_for_customcall, self.mask_sharding), + jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), + jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), + jax.device_put(self.offsets_q, self.seq_length_offset_sharding), + jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -555,10 +647,31 @@ def test_forward(self): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } - # Convert the outputs to float32 for the elementwise comparison - primitive_out = customcall_fused_dpa(*customcall_args, **kwargs) + customcall_fused_dpa_jit = jit( + partial(customcall_fused_dpa, **kwargs), + static_argnames=kwargs.keys(), + in_shardings=[ + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.mask_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.dropout_rng_sharding, + ], + ) + + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out = customcall_fused_dpa_jit(*customcall_args) + primitive_out = self.cp_inverse_reorder_fn(primitive_out) + reference_out = jax_dpa(*args, **kwargs) if self.is_training and self.dropout_prob > 0.0: @@ -571,9 +684,19 @@ def test_forward(self): 
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = ( + customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() + ) + assert_equal_collectives(target_hlo, self.coll_count_ref) + def test_backward(self): """ - Test value_and_grad with JIT, which includes both forward and backward + Test value_and_grad with JIT, which includes both forward and backward. + + If coll_count_ref is not None then the HLO of the backwrds function + HLO will be examined for the expected comms. """ self._setup_inputs() @@ -587,20 +710,24 @@ def grad_func(func, *args, **kwargs): ret_valid = jnp.where( self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs) ) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype) + return ( + jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier + ).astype(self.dtype) args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] customcall_args = [ - self.q, - self.k, - self.v, - self.bias, - self.mask_for_customcall, - self.seqlens_q, - self.seqlens_kv, - self.offsets_q, - self.offsets_kv, - self.dropout_rng, + # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and + # THD params once we support those features on CP. 
+ jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.mask_for_customcall, self.mask_sharding), + jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), + jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), + jax.device_put(self.offsets_q, self.seq_length_offset_sharding), + jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -611,10 +738,22 @@ def grad_func(func, *args, **kwargs): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) + if self.bias_shape == BiasShape._1HSS: + arg_nums = (0, 1, 2, 3) + grad_shardings = ( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + ) + else: + arg_nums = (0, 1, 2) + grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -623,7 +762,20 @@ def grad_func(func, *args, **kwargs): customcall_fused_dpa, q, k, v, bias, *args, **kwargs ), arg_nums, - ) + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.mask_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.dropout_rng_sharding, + ), + out_shardings=(None, 
grad_shardings), ) jitted_reference = jit( value_and_grad( @@ -632,20 +784,31 @@ def grad_func(func, *args, **kwargs): ) ) - primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + reference_out, reference_dgrad = jitted_reference(*args) # Skip elementwise comparison when dropout enabled if self.dropout_prob > 0.0: return + print_debug_tensor_stats(f"primitive_out", primitive_out) + print_debug_tensor_stats(f"reference_grad_valid", reference_out) + print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out)) assert_allclose(primitive_out, reference_out, dtype=self.dtype) - def check_dqkv(primitive, reference, pad): + def check_dqkv(primitive, reference, pad, idx): primitive_valid, primitive_invalid, reference_valid, reference_invalid = ( _split_valid_and_invalid(primitive, reference, pad) ) + print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx]) + print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx]) + print_debug_tensor_stats( + f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx]) + ) + assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) @@ -653,11 +816,17 @@ def check_dqkv(primitive, reference, pad): primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3] reference_dq, reference_dk, reference_dv = reference_dgrad[:3] - check_dqkv(primitive_dq, reference_dq, self.pad_q) - check_dqkv(primitive_dk, reference_dk, self.pad_kv) - check_dqkv(primitive_dv, reference_dv, self.pad_kv) + primitive_dq = self.cp_inverse_reorder_fn(primitive_dq) + primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) + primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + + 
check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) + check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) + check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: + # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation. + primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -685,6 +854,11 @@ def check_dqkv(primitive, reference, pad): dtype=self.dtype, ) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() + assert_equal_collectives(target_hlo, self.coll_count_ref) + @pytest.mark.parametrize( "attn_mask_type", From 7aa81186f95a0aea4c916fae4cf0efd05054bb62 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 22 Jan 2025 02:21:27 +0800 Subject: [PATCH 058/239] [PyTorch] Fix AttentionParams comparison logic (#1397) only compare the recipe in AttentionParams.fp8_meta Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 55c8a2fcf2..f2120f3a73 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -303,6 +303,24 @@ class AttentionParams: fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None + def __eq__(self, other): + """ + Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared, + since all other entries of fp8_meta are unused in get_attention_backend. 
+ """ + if not isinstance(other, self.__class__): + return NotImplemented + for field in fields(self): + fname = field.name + sf = getattr(self, fname) + of = getattr(other, fname) + if fname != "fp8_meta": + if sf != of: + return False + elif sf["recipe"] != of["recipe"]: + return False + return True + _alibi_cache = { "_num_heads": None, From 3d7ff1c63aefcdd47ad011af95f5f9bae38679c4 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 22 Jan 2025 13:06:14 -0800 Subject: [PATCH 059/239] [PyTorch] Avoid `parameters` function in op backward pass (#1403) * Avoid `parameters` function in op backward pass Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/ops/fuser.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index dc96c12523..7c638032f1 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -192,7 +192,7 @@ def forward( func_ctx.backward_ops = backward_ops func_ctx.basic_ops = basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.num_params = num_params + func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops] func_ctx.num_extra_inputs = num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() @@ -258,14 +258,14 @@ def backward( # Flatten list of parameter gradients grad_params_flat = [] for idx, dparams in enumerate(grad_params): - params = list(basic_ops[idx].parameters()) + num_params = func_ctx.basic_op_num_params[idx] if dparams is None: - dparams = [None for _ in range(len(params))] + dparams = [None for _ in 
range(num_params)] else: dparams = list(dparams) - if len(dparams) != len(params): + if len(dparams) != num_params: raise RuntimeError( - f"Expected op {idx} to generate {len(params)} param grads, " + f"Expected op {idx} to generate {num_params} param grads, " f"but got {len(dparams)}" ) grad_params_flat.extend(dparams) From c2c3d540b1eca9ccbcc0fa7cb871688814a536f9 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 24 Jan 2025 13:50:52 +0800 Subject: [PATCH 060/239] [JAX] Support segment_ids/pos as FA inputs (#1406) * POC for segment_ids/segment_pos Signed-off-by: Reese Wang * Change segment_pos position Signed-off-by: Reese Wang * Use RemainingArgs to solve number of parameters mismatches Signed-off-by: Reese Wang * Test mask_descriptor for accomendating different mask representations Signed-off-by: Reese Wang * Fix bugs Signed-off-by: Reese Wang * Use descriptor in bwd Signed-off-by: Reese Wang * Primitives only accepts pure jnp array Signed-off-by: Reese Wang * segment_ids/pos support POC Signed-off-by: Reese Wang * Move seqlens/offsets generation to mask descriptor Signed-off-by: Reese Wang * Rename MaskDescriptor to SequenceDescriptor Signed-off-by: Reese Wang * Generalize get_seqlens_and_offsets Signed-off-by: Reese Wang * Utilize sequence desc on FA bwd Signed-off-by: Reese Wang * Migrate to new API Signed-off-by: Reese Wang * Add docstrings Signed-off-by: Reese Wang * Remove small inputs and test different input format Signed-off-by: Reese Wang * Fix lint Signed-off-by: Reese Wang * Fix seed shardings Signed-off-by: Reese Wang * Optimize sequence converting overhead Signed-off-by: Reese Wang * Optimize seq_offsets calculation Signed-off-by: Reese Wang * Fix up Signed-off-by: Reese Wang * fix lint Signed-off-by: Reese Wang * Fix conflicts Signed-off-by: Reese Wang * Remove reduntant line Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_distributed_fused_attn.py | 21 +- tests/jax/test_fused_attn.py | 160 +++-- 
transformer_engine/jax/attention.py | 554 ++++++++++++++---- .../jax/cpp_extensions/attention.py | 290 ++++++--- .../jax/csrc/extensions/attention.cpp | 58 +- 5 files changed, 786 insertions(+), 297 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 2e15dd4d5d..d7e015dbf7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -2,31 +2,18 @@ # # See LICENSE for license information. -import pytest -from functools import partial - import jax import jax.numpy as jnp import numpy as np -from flax.linen import dot_product_attention from jax import random -from jax.sharding import Mesh, NamedSharding, PartitionSpec from distributed_test_base import ( generate_configs, generate_context_parallel_configs, generate_collectives_count, - compare_ops, -) -from utils import ( - make_causal_mask, - make_self_mask, - assert_allclose, - print_debug_tensor_stats, ) from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, - fused_attn, AttnBiasType, AttnMaskType, QKVLayout, @@ -36,10 +23,11 @@ CPStrategy, ) from transformer_engine.jax.sharding import MeshResource +import pytest -from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask +from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat -DTYPES = [jnp.float16, jnp.bfloat16] +DTYPES = [jnp.bfloat16] class TestDistributedSelfAttn: @@ -141,6 +129,7 @@ def test_self_attn( QKVLayout.BS3HD, bias_shape, None, + SeqDescFormat.Seqlens, number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, @@ -205,6 +194,7 @@ def test_cross_attn( QKVLayout.BSHD_BS2HD, bias_shape, None, + SeqDescFormat.Seqlens, number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, @@ -293,6 +283,7 @@ def impl_test_context_parallel_attn( qkv_layout, bias_shape, None, + SeqDescFormat.Seqlens, 
number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 710ae1946d..beaf18cea3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. """Tests for fused attention""" -from enum import Enum +from enum import Enum, auto from dataclasses import dataclass, field from functools import partial from math import sqrt @@ -28,12 +28,11 @@ AttnBiasType, AttnMaskType, QKVLayout, - QKVFormat, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, fused_attn, - fused_attn_thd, make_swa_mask, + SequenceDescriptor, CPStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper @@ -199,8 +198,8 @@ def _find_offsets(x): ).squeeze(-1) offsets = _find_offsets(segment_ids) - offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - seqlens = jnp.insert(seqlens, -1, values=0, axis=-1) + offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1) + seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1) seqlens = jnp.where(seqlens, seqlens, -1) return seqlens, offsets @@ -239,11 +238,7 @@ def customcall_fused_dpa( key, value, bias, - mask, - seqlens_q, - seqlens_kv, - offsets_q, - offsets_kv, + sequence_descriptor, dropout_rng, **kwargs, ): @@ -264,19 +259,9 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not qkv_layout.is_thd(): - kwargs.pop("max_segments_per_seq") - return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) - return fused_attn_thd( - qkv_args, - bias, - seqlens_q, - seqlens_kv, - offsets_q, - offsets_kv, - dropout_rng, - **kwargs, - ).astype(query.dtype) + return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype( + query.dtype + ) class BiasShape(Enum): @@ -290,6 +275,12 @@ class BiasShape(Enum): _11SS = "11SS" +class 
SeqDescFormat(Enum): + Mask = auto() + Seqlens = auto() + SegmentIDs = auto() + + @dataclass class FusedAttnRunner: """ @@ -309,7 +300,8 @@ class FusedAttnRunner: is_training: bool qkv_layout: QKVLayout bias_shape: BiasShape - window_size: Optional[Tuple[int, int]] = None + window_size: Tuple[int, int] + seq_desc_format: SeqDescFormat # Specifies sharding resources for distributed tests number_of_devices: int = 1 @@ -327,11 +319,14 @@ class FusedAttnRunner: # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): - if 90400 <= get_cudnn_version() < 90500: - return self.num_segments_per_seq + if self.qkv_layout.is_thd(): + if 90400 <= get_cudnn_version() < 90500: + return self.num_segments_per_seq + else: + # +1 for testing runtime_segments < max_segments + return self.num_segments_per_seq + 1 else: - # +1 for testing runtime_segments < max_segments - return self.num_segments_per_seq + 1 + return 1 def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available @@ -462,11 +457,11 @@ def generate_random_segment_ids( ): rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad - segment_ids = np.zeros((batch_size, sequence_length), dtype=int) - segment_pos = np.zeros((batch_size, sequence_length), dtype=int) + segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32) + segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32) # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0] # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad - segment_pad = np.zeros((batch_size, sequence_length), dtype=int) + segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32) # Not include paddings max_segment_size = sequence_length // num_segments @@ -541,16 +536,47 @@ def generate_random_segment_ids( self.window_size, ) + # Test 
different input formats if self.qkv_layout.is_thd(): - self.mask_for_customcall = None # THD format doesn't support mask + match self.seq_desc_format: + case SeqDescFormat.Mask: + pytest.skip("THD doesn't support mask input") + case SeqDescFormat.Seqlens: + self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets( + (self.seqlens_q, self.seqlens_kv), + (self.offsets_q, self.offsets_kv), + ) + case SeqDescFormat.SegmentIDs: + self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( + (self.segment_ids_q, self.segment_ids_kv), + (self.segment_pos_q, self.segment_pos_kv), + ) + case _: + raise ValueError(f"Unknown {self.seq_desc_format=}") else: - self.mask_for_customcall = make_mask( - self.segment_ids_q, - self.segment_ids_kv, - self.segment_pos_q, - self.segment_pos_kv, - self.attn_mask_type, - ) + match self.seq_desc_format: + case SeqDescFormat.Mask: + self.sequence_desciptor = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) + case SeqDescFormat.Seqlens: + self.sequence_desciptor = SequenceDescriptor.from_seqlens( + ( + self.segment_ids_q.sum(axis=-1).astype(jnp.int32), + self.segment_ids_kv.sum(axis=-1).astype(jnp.int32), + ), + ) + case SeqDescFormat.SegmentIDs: + self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( + (self.segment_ids_q, self.segment_ids_kv), + None, + ) + case _: + raise ValueError(f"Unknown {self.seq_desc_format=}") self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) @@ -565,10 +591,21 @@ def generate_random_segment_ids( ) self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec) - self.mask_pspec = PartitionSpec( + mask_pspec = PartitionSpec( self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None ) - self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec) + self.mask_sharding = NamedSharding(self.mesh, mask_pspec) + + match 
self.seq_desc_format: + case SeqDescFormat.Mask: + self.seq_desc_sharding = self.mask_sharding + case _: + + def to_dp_shardings(x): + pspec = PartitionSpec(self.mesh_resource.dp_resource) + return NamedSharding(self.mesh, pspec) + + self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor) if self.bias_shape == BiasShape._1HSS: self.bias_pspec = PartitionSpec( @@ -631,11 +668,7 @@ def test_forward(self): jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), jax.device_put(self.bias, self.bias_sharding), - jax.device_put(self.mask_for_customcall, self.mask_sharding), - jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), - jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), - jax.device_put(self.offsets_q, self.seq_length_offset_sharding), - jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { @@ -659,11 +692,7 @@ def test_forward(self): self.qkvo_sharding, self.qkvo_sharding, self.bias_sharding, - self.mask_sharding, - self.seq_length_offset_sharding, - self.seq_length_offset_sharding, - self.seq_length_offset_sharding, - self.seq_length_offset_sharding, + self.seq_desc_sharding, self.dropout_rng_sharding, ], ) @@ -722,11 +751,7 @@ def grad_func(func, *args, **kwargs): jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), jax.device_put(self.bias, self.bias_sharding), - jax.device_put(self.mask_for_customcall, self.mask_sharding), - jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), - jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), - jax.device_put(self.offsets_q, self.seq_length_offset_sharding), - jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + 
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { @@ -768,11 +793,7 @@ def grad_func(func, *args, **kwargs): self.qkvo_sharding, self.qkvo_sharding, self.bias_sharding, - self.mask_sharding, - self.seq_length_offset_sharding, - self.seq_length_offset_sharding, - self.seq_length_offset_sharding, - self.seq_length_offset_sharding, + self.seq_desc_sharding, self.dropout_rng_sharding, ), out_shardings=(None, grad_shardings), @@ -883,10 +904,7 @@ def check_dqkv(primitive, reference, pad, idx): @pytest.mark.parametrize( "b, s_q, s_kv, h_q, h_kv, d, dtype", [ - pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"), - pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), - pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"), pytest.param( 2, 2048, @@ -897,8 +915,8 @@ def check_dqkv(primitive, reference, pad, idx): jnp.bfloat16, id="2-2048-1024-12-12-64-BF16-CROSS", ), - pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"), pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), + pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), ], ) @pytest.mark.parametrize( @@ -915,6 +933,14 @@ def check_dqkv(primitive, reference, pad, idx): pytest.param(True, id="SWA"), ], ) +@pytest.mark.parametrize( + "seq_desc_format", + [ + pytest.param(SeqDescFormat.Mask, id="Mask"), + pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), + pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"), + ], +) class TestFusedAttn: """ Fused attention tester @@ -953,6 +979,7 @@ def _test_forward( qkv_layout, bias_shape, swa, + seq_desc_format, ): """ Test forward with parameterized configs @@ -977,6 +1004,7 @@ def _test_forward( 
qkv_layout, bias_shape, window_size, + seq_desc_format, ) runner.test_forward() @@ -1002,6 +1030,7 @@ def test_backward( qkv_layout, bias_shape, swa, + seq_desc_format, ): """ Test backward with parameterized configs @@ -1024,5 +1053,6 @@ def test_backward( qkv_layout, bias_shape, window_size, + seq_desc_format, ) runner.test_backward() diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 7b6c605236..09128b013b 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -2,13 +2,16 @@ # # See LICENSE for license information. """JAX multi-head attention modules""" - +from __future__ import annotations from enum import Enum from functools import partial -from typing import Optional, Tuple +from typing import Optional, Tuple, Union +import warnings + from jax.ad_checkpoint import checkpoint_name import jax import jax.numpy as jnp +from flax.linen import make_attention_mask from transformer_engine.transformer_engine_jax import NVTE_Bias_Type from transformer_engine.transformer_engine_jax import NVTE_Mask_Type @@ -252,28 +255,24 @@ def make_helper(attn_mask_type): (-1, -1) if window_size is None else window_size, ) - if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): - return False - - return True + return make_helper(attn_mask_type).is_fused_attn_kernel_available() def _obtain_batch_and_max_seqlen(qkv, qkv_layout): - match qkv_layout: - case QKVLayout.BS3HD | QKVLayout.T3HD: - assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = q_max_seqlen - case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD: - assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = qkv[1].shape[1] - case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD: - assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}" - batch, q_max_seqlen, *_ = 
qkv[0].shape - kv_max_seqlen = qkv[1].shape[1] - case _: - raise ValueError(f"Unsupported {qkv_layout=}") + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = q_max_seqlen + elif qkv_layout.is_kvpacked(): + assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = qkv[1].shape[1] + elif qkv_layout.is_separate(): + assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = qkv[1].shape[1] + else: + raise ValueError(f"Unsupported {qkv_layout=}") return batch, q_max_seqlen, kv_max_seqlen @@ -289,7 +288,273 @@ def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: Q return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) -def fused_attn( +def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq): + # bincount map with 0s + bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1)) + seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) + seqlens = seqlens_with_zero[..., 1:] + + def _find_offsets(x): + same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) + first_column = x[..., :1] != 0 + same_as_previous = jnp.hstack((first_column, same_as_previous)) + return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))( + same_as_previous + ).squeeze(-1) + + offsets = _find_offsets(segment_ids) + return seqlens, offsets + + +def _mask_to_seqlens_offset(mask, max_segments_per_seq): + assert mask.shape[1] == 1 + row_ids = mask.squeeze(axis=1).max(axis=-1) + q_seqlen, q_offset = _get_seqlens_and_offsets(row_ids, max_segments_per_seq) + col_ids = mask.squeeze(axis=1).max(axis=-2) + kv_seqlen, kv_offset = _get_seqlens_and_offsets(col_ids, max_segments_per_seq) + return q_seqlen, 
q_offset, kv_seqlen, kv_offset + + +def _segment_ids_pos_to_seqlens_offsets( + segment_ids_q, + segment_ids_kv, + segment_pos_q, + segment_pos_kv, + attn_mask_type, + window_size, + max_segments_per_seq, +): + # (1 = attend, 0 = masked) + segment_mask = make_attention_mask( + segment_ids_q, + segment_ids_kv, + jnp.equal, + ) + segment_mask_with_id = make_attention_mask( + segment_ids_q, + segment_ids_kv, + lambda x, y: jnp.equal(x, y) * x, + ) + attn_mask = segment_mask + if attn_mask_type.is_causal(): + causal_mask = make_attention_mask( + segment_pos_q, + segment_pos_kv, + jnp.greater_equal, + ) + attn_mask = jnp.logical_and(segment_mask, causal_mask) + + swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + attn_mask = jnp.logical_and(attn_mask, swa_mask) + + attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) + q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( + attn_mask_with_id, max_segments_per_seq + ) + return q_seqlen, kv_seqlen, q_offset, kv_offset + + +def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type): + # convert the mask to seqlens, mask doesn't support ragged offsets + if not attn_mask_type.is_padding(): + q_max_seqlen = segment_ids_q.shape[-1] + kv_max_seqlen = segment_ids_kv.shape[-1] + q_seq_lens = jnp.full_like(q_max_seqlen, q_max_seqlen, dtype=jnp.int32) + kv_seq_lens = jnp.full_like(kv_max_seqlen, kv_max_seqlen, dtype=jnp.int32) + else: + q_seq_lens = jnp.sum(segment_ids_q, axis=-1).astype(jnp.int32) + kv_seq_lens = jnp.sum(segment_ids_kv, axis=-1).astype(jnp.int32) + return q_seq_lens, kv_seq_lens + + +@jax.tree_util.register_pytree_node_class +class SequenceDescriptor: + """A class to descibe the sequences with flexible initialization. + - SequenceDescriptor.from_seqlens + For non-THD (non-packed) cases, where each batch has only 1 sequence. 
+ - SequenceDescriptor.from_seqlens_and_offsets + For THD (packed) cases, where each batch may have not only 1 sequence. + - SequenceDescriptor.from_segment_ids_and_pos + Experimental feature for THD (packed) cases with context parallelism. + """ + + seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + seq_offsets: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + segment_ids: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + segment_pos: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + + def __init__(self, seqlens=None, seq_offsets=None, segment_ids=None, segment_pos=None): + """ + Initialize to Tuple(jnp.zeros, jnp.zeros) because the primitive only accepts pure jax array + """ + self.seqlens = (jnp.zeros(0), jnp.zeros(0)) if seqlens is None else seqlens + self.seq_offsets = (jnp.zeros(0), jnp.zeros(0)) if seq_offsets is None else seq_offsets + self.segment_ids = (jnp.zeros(0), jnp.zeros(0)) if segment_ids is None else segment_ids + self.segment_pos = (jnp.zeros(0), jnp.zeros(0)) if segment_pos is None else segment_pos + + def tree_flatten(self): + """ + Flatten method to register as a pytree node + """ + return ((self.seqlens, self.seq_offsets, self.segment_ids, self.segment_pos), None) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """ + Unflatten method to register as a pytree node + """ + del aux_data + return cls(*children) + + def get_seqlens_and_offsets( + self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq + ): + """ + Acquire the seqlens/offsets for cuDNN backend + """ + attn_mask_type = AttnMaskType(attn_mask_type) + qkv_layout = QKVLayout(qkv_layout) + q_segment_ids, kv_segment_ids = self.segment_ids + q_segment_pos, kv_segment_pos = self.segment_pos + assert q_segment_ids.shape == q_segment_pos.shape + assert kv_segment_ids.shape == kv_segment_pos.shape + # No segment_ids/segment_pos + if q_segment_ids.size + kv_segment_ids.size == 0: + return self.seqlens, self.seq_offsets + + if qkv_layout.is_thd(): + q_seqlens, kv_seqlens, 
q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + else: + q_seqlens, kv_seqlens = _segment_ids_to_seqlens( + q_segment_ids, + kv_segment_ids, + attn_mask_type, + ) + q_offsets = kv_offsets = jnp.zeros(0) + return (q_seqlens, kv_seqlens), (q_offsets, kv_offsets) + + @classmethod + def _expand_to_pair( + cls, value: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Internal helper to ensure a single value expands into a pair (q, kv). + """ + if isinstance(value, tuple): + if len(value) != 2: + raise ValueError("Input tuple must have exactly 2 elements.") + return value + + if isinstance(value, jnp.ndarray): + return value, value # Duplicate for q=kv case + + raise TypeError( + "Expected a jax.numpy.ndarray or a tuple of two jax.numpy.ndarray, " + f"but got {type(value).__name__}." + ) + + @classmethod + def from_seqlens( + cls, + seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + ) -> SequenceDescriptor: + """ + Factory method for inputs with sequence lengths only (non-THD). + Args: + seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens): + - q_seqlens (jnp.ndarray): + Sequence lengths for the query, with shape [batch]. + - kv_seqlen (jnp.ndarray): + Sequence lengths for the key and value, with shape [batch]. + Return: + A SequenceDescriptor with only seqlens initialized. + """ + q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens) + return cls(seqlens=(q_seqlens, kv_seqlens)) + + @classmethod + def from_seqlens_and_offsets( + cls, + seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + seq_offsets: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + ) -> SequenceDescriptor: + """ + Factory method for inputs with sequence lengths and offsets (THD). 
+ Args: + seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens): + - q_seqlens (jnp.ndarray): + Sequence lengths for the query, with shape [batch, max_seqlen]. + Unused positions are padded with -1. + - kv_seqlen (jnp.ndarray): + Sequence lengths for the key and value, with shape [batch, max_seqlen]. + Unused positions are padded with -1. + seq_offsets(Tuple(jnp.ndarray, jnp.ndarray)) = (q_offsets, kv_offsets) + - q_seq_offsets (jnp.ndarray): + The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. + Unused positions are padded with -1. + - kv_seq_offsets (jnp.ndarray): + The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. + Unused positions are padded with -1. + Return: + A SequenceDescriptor with seqlens/seq_offsets initialized. + """ + q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens) + q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets) + return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets)) + + @classmethod + def from_segment_ids_and_pos( + cls, + segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> SequenceDescriptor: + """ + Experimental factory method for inputs with segment IDs and optional positions. (THD) + Args: + segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): + - q_segment_ids (jnp.ndarray): + Query segment ids start with 1, with shape [batch, max_seqlen]. + 0s are treated as paddings. + - kv_segment_ids (jnp.ndarray): + Key, value segment ids start with 1, with shape [batch, max_seqlen]. + 0s are treated as paddings. + segment_pos(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_pos, kv_segment_pos) + - q_segment_pos (jnp.ndarray): + The position inside each segment for query, with shape [batch, max_seqlen]. 
+ - kv_segment_pos (jnp.ndarray): + The position inside each segment for key, value, with shape [batch, max_seqlen]. + Return: + A SequenceDescriptor with segment_ids/segment_pos initialized. + """ + q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) + + if segment_pos is not None: + segment_pos = cls._expand_to_pair(segment_pos) + else: + + def generate_default_pos(segment_ids): + seqlen = segment_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) + + q_seg_pos = generate_default_pos(q_seg_ids) + kv_seg_pos = generate_default_pos(kv_seg_ids) + segment_pos = (q_seg_pos, kv_seg_pos) + + return cls( + segment_ids=(q_seg_ids, kv_seg_ids), + segment_pos=segment_pos, + ) + + +def _legacy_fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], mask: Optional[jnp.ndarray], @@ -372,10 +637,7 @@ def fused_attn( output = _fused_attn( qkv, bias, - q_seq_lens, - kv_seq_lens, - None, - None, + SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)), seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -414,63 +676,13 @@ def fused_attn_thd( context_parallel_axis: str = "", ): """ - (Experimental) Perform THD (packed) cuDNN fused attention. - - This function implements the following formula: - BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 - Args: - qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors. - It supports three formats: - - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key, - and value have the same shape (e.g., self-attention). - - `(query, kv_packed)`: For separate query and KV packed format, typically used when - query has a different shape (e.g., cross-attention). - - `(query, key, value)`: For separate query, key, and value tensors. - bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. 
- q_seqlen (jnp.ndarray): - Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are - padded with -1. - kv_seqlen (jnp.ndarray): - Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions - are padded with -1. - q_seq_offsets (jnp.ndarray): - The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. - Unused positions are padded with -1. - kv_seq_offsets (jnp.ndarray): - The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. - Unused positions are padded with -1. - seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. - scaling_factor (float): Scaling factor for the attention scores. - dropout_probability (float): Dropout probability to apply during attention. - is_training (bool): Flag indicating whether the model is in training mode. - max_segments_per_seq (int): - Indicating the maximum number of segments inside a sequence. This parameter is to - constrain the limit usage and need to be static during the e2e training. The XLA compile - time and memory consumption is proportional to `max_segments_per_seq`. - window_size (Optional[Tuple[int, int]]): - Sliding window size. - context_parallel_causal_load_balanced (bool): - Indicates the sequences are ordered for causal mask load balancing when running context parallelism. - context_parallel_axis (str): The name of the context parallel axis. - Returns: - (jnp.ndarray): The output tensor from the fused attention. 
- - Examples: - >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens - >>> b, s, h, d = 2, 4, 12, 64 - >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16) - >>> # 3 segments in first seq, 2 segments in second seq - >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]]) - >>> # seq_offsets need to include the end offset of the last segments - >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]]) - >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens, - q_seq_offsets, kv_seq_offsets, None, - AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, - QKVLayout.T3HD, 0.125, 0, True, 3) + Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor """ + warnings.warn( + "fused_attn_thd is deprecated, please use fused_attn with SequenceDescriptor", + DeprecationWarning, + ) + assert ( qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format." @@ -497,10 +709,9 @@ def fused_attn_thd( output = _fused_attn( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + SequenceDescriptor.from_seqlens_and_offsets( + (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets) + ), seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -518,15 +729,12 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - q_seq_lens: jnp.ndarray, - kv_seq_lens: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], - seed: jnp.ndarray, + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, qkv_layout: QKVLayout, @@ -542,10 +750,7 @@ def _fused_attn( output, _ = _fused_attn_fwd_rule( qkv, 
bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type, attn_mask_type, @@ -565,10 +770,7 @@ def _fused_attn( def _fused_attn_fwd_rule( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type, attn_mask_type, @@ -585,10 +787,7 @@ def _fused_attn_fwd_rule( output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, @@ -608,10 +807,7 @@ def _fused_attn_fwd_rule( return output, ( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, softmax_aux, rng_state, output, @@ -636,10 +832,7 @@ def _fused_attn_bwd_rule( ( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, softmax_aux, rng_state, output, @@ -651,10 +844,7 @@ def _fused_attn_bwd_rule( rng_state, output, dz, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, qkv_layout=qkv_layout.value, @@ -669,7 +859,137 @@ def _fused_attn_bwd_rule( ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None - return grad_qkv, grad_bias, None, None, None, None, None + return ( + grad_qkv, + grad_bias, + None, + None, + ) _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) + + +def fused_attn( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int = 1, + window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, + 
context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", +): + """ + Perform cuDNN fused attention. + + This function implements the following formula: + BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + Args: + qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors. + It supports three formats: + - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key, + and value have the same shape (e.g., self-attention). + - `(query, kv_packed)`: For separate query and KV packed format, typically used when + query has a different shape (e.g., cross-attention). + - `(query, key, value)`: For separate query, key, and value tensors. + bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. + sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence. + seed (Optional[jnp.ndarray]): Optional random seed for dropout. + attn_bias_type (NVTE_Bias_Type): Type of attention bias. + attn_mask_type (NVTE_Mask_Type): Type of attention mask. + qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + scaling_factor (float): Scaling factor for the attention scores. + dropout_probability (float): Dropout probability to apply during attention. + is_training (bool): Flag indicating whether the model is in training mode. + max_segments_per_seq (int): + Indicating the maximum number of segments inside a sequence. This parameter is to + constrain the limit usage and need to be static during the e2e training. The XLA compile + time and memory consumption is proportional to `max_segments_per_seq`. + window_size (Optional[Tuple[int, int]]): + Sliding window size. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. 
+ Returns: + (jnp.ndarray): The output tensor from the fused attention. + + Examples (non-THD, also known as non-packed): + >>> # q_segment_ids = [[1, 1, 1, 0], [1, 1, 0, 0]], 0 means padded tokens + >>> # kv_segment_ids = [[1, 0, 0, 0], [1, 1, 0, 0]], 0 means padded tokens + >>> b, s, h, d = 2, 4, 12, 64 + >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16) + >>> q_seq_lens = jnp.asarray([3, 2]) + >>> kv_seq_lens = jnp.asarray([1, 2]) + >>> sequence_desc = SequenceDescriptor.from_seqlens( + seqlens=(q_seq_lens, kv_seq_lens)) + >>> out = fused_attn((qkv,), None, sequence_desc, None, + AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, + QKVLayout.BS3HD, 0.125, 0, True, 3) + + Examples (THD, also known as packed): + >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens + >>> # segment_pos = [[0, 1, 0, 0], [0, 1, 0, 1]] + >>> b, s, h, d = 2, 4, 12, 64 + >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16) + >>> # 3 segments in first seq, 2 segments in second seq + >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]]) + >>> # seq_offsets need to include the end offset of the last segments + >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]]) + >>> sequence_desc = SequenceDescriptor.from_seqlens_and_offsets( + seqlens=(q_seq_lens, kv_seq_lens), + seq_offsets=(q_seq_offsets, kv_seq_offsets)) + >>> out = fused_attn((qkv,), None, sequence_desc, None, + AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, + QKVLayout.T3HD, 0.125, 0, True, 3) + """ + if isinstance(sequence_descriptor, jnp.ndarray): + warnings.warn( + "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. 
" + + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.", + DeprecationWarning, + ) + if max_segments_per_seq != 1: + raise ValueError("Passing mask is only supported for non-THD case.") + return _legacy_fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + ) + output = _fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + ) + + return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 3a116ffb63..ae3cfddccc 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -17,7 +17,7 @@ from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi -from transformer_engine.jax.attention import CPStrategy +from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import ( @@ -211,9 +211,8 @@ def generate_cu_seqlen(actual_seqlen): """ Generating cumsum seqlen for a batch """ - cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1) - 
cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen) - cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1) + actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen) + cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True) return cu_seqlen @@ -224,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (9,) + impl_static_args = (13,) inner_primitive = None outer_primitive = None @@ -234,11 +233,15 @@ def abstract( k_aval, v_aval, bias_aval, + seed_aval, q_seqlen_or_cu_seqlen_aval, kv_seqlen_or_cu_seqlen_aval, _q_seq_offsets, _k_seq_offsets, - seed_aval, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config: _FusedAttnConfig, ): @@ -354,11 +357,15 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config: _FusedAttnConfig, ): @@ -387,11 +394,15 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -417,11 +428,11 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, ] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ @@ -466,15 +477,35 @@ def impl( k, v, bias, + seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config: _FusedAttnConfig, ): assert FusedAttnFwdPrimitive.inner_primitive is not None + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, 
_kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): @@ -517,6 +548,7 @@ def convert_to_2d(offsets, batch, max_seqlen): fill_value = 0 else: fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) @@ -524,15 +556,17 @@ def convert_to_2d(offsets, batch, max_seqlen): # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + # Gather valid q_seq_offsets, which is greater and equal to 0 # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0) - k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0) - - # Set the unused position to max size (batch * max_seqlen) + # And set the unused position to max size (batch * max_seqlen) # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] - q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets) - k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) @@ -542,11 +576,15 @@ def convert_to_2d(offsets, batch, max_seqlen): k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, 
q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) return output, softmax_aux, rng_state @@ -555,7 +593,7 @@ def convert_to_2d(offsets, batch, max_seqlen): def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims + q_bdim, _, _, _, seed_bdim, *_ = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim return ( @@ -600,7 +638,9 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @@ -616,7 +656,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (12,) + impl_static_args = (16,) inner_primitive = None outer_primitive = None @@ -634,6 +674,10 @@ def abstract( kv_seqlen_or_cu_seqlen_aval, _q_seq_offsets, _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config, ): @@ -718,6 +762,10 @@ def lowering( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, *, config, ): @@ -754,6 +802,10 @@ def lowering( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -839,10 +891,30 @@ def impl( kv_seqlen, q_seq_offsets, 
k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config, ): assert FusedAttnBwdPrimitive.inner_primitive is not None + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): @@ -893,15 +965,17 @@ def convert_to_2d(offsets, batch, max_seqlen): # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + # Gather valid q_seq_offsets, which is greater and equal to 0 # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0) - k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0) - - # Set the unused position to max size (batch * max_seqlen) + # And set the unused position to max size (batch * max_seqlen) # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] - q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets) - k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) @@ -919,6 +993,10 
@@ def convert_to_2d(offsets, batch, max_seqlen): kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) return dq, dk, dv, dbias @@ -975,6 +1053,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( q, @@ -989,6 +1071,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) global_dbias = local_dbias @@ -1240,10 +1326,26 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) @@ -1280,11 +1382,15 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): k_unmasked, v_unmasked, bias, + seed, q_seqlen_for_step, kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(), ) results.append((output, softmax_aux, rng_state)) @@ -1357,13 +1463,31 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): 
cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. def _cross_attn_bwd( - idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen + idx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) @@ -1402,6 +1526,10 @@ def _cross_attn_bwd( kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(), ) @@ -1433,6 +1561,10 @@ def _cross_attn_bwd( doutput, q_seqlen, kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ) for idx in range(cp_size) ] @@ -1595,7 +1727,9 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) def ring_attn_fwd_impl( @@ -1603,11 +1737,15 @@ def ring_attn_fwd_impl( k, v, bias, + seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): _not_used = jnp.zeros(0, dtype=v.dtype) @@ -1644,12 +1782,16 @@ def mask_compute(attn_mask_type): kv, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, - helper.get_step_config(attn_mask_type), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + 
config=helper.get_step_config(attn_mask_type), ) return output_per_step, softmax_aux_per_step @@ -1665,11 +1807,15 @@ def half_kv_no_mask_compute(): kv_part, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) return output_per_step, softmax_aux_per_step @@ -1683,11 +1829,15 @@ def half_q_no_mask_compute(): kv, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) @@ -1805,6 +1955,10 @@ def ring_attn_bwd_impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): _not_used = jnp.zeros(0, dtype=output.dtype) @@ -1849,6 +2003,10 @@ def mask_compute(attn_mask_type): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(attn_mask_type), ) return dq_per_step, dk_dv_per_step, dbias_per_step @@ -1873,6 +2031,10 @@ def half_kv_no_mask_compute(): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) dk_dv_per_step = jnp.concat( @@ -1907,6 +2069,10 @@ def half_q_no_mask_compute(): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) @@ -1975,10 +2141,7 @@ def _maybe_context_parallel_axis(cp_axis: str): def fused_attn_fwd( qkv: 
Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, seed: Optional[jnp.ndarray], attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, @@ -2031,14 +2194,9 @@ def fused_attn_fwd( (jnp.ndarray): The output tensor from the fused attention. """ seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training) - - assert (q_seq_offsets is None) == ( - kv_seq_offsets is None - ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values." - is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD - # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) + match qkv_layout: case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" @@ -2071,21 +2229,19 @@ def fused_attn_fwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - primative = None + primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primative = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive + primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive case CPStrategy.RING: - primative = FusedRingAttnFwdPrimitive.outer_primitive + primitive = FusedRingAttnFwdPrimitive.outer_primitive - return primative.bind( + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) + return primitive.bind( *qkv_for_primitive, bias, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, seed, + *seq_desc_flatten, config=fused_config, ) @@ -2097,10 +2253,7 @@ def fused_attn_bwd( rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, - q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, - q_seq_offsets: 
Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, qkv_layout: NVTE_QKV_Layout, @@ -2155,12 +2308,6 @@ def fused_attn_bwd( same format as the input `qkv`. - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`. """ - - assert (q_seq_offsets is None) == ( - kv_seq_offsets is None - ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values." - is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD - # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) @@ -2196,24 +2343,23 @@ def fused_attn_bwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - primative = None + primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primative = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive + primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive case CPStrategy.RING: - primative = FusedRingAttnBwdPrimitive.outer_primitive + primitive = FusedRingAttnBwdPrimitive.outer_primitive + + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) - *qkv_grads, bias_grad = primative.bind( + *qkv_grads, bias_grad = primitive.bind( *qkv_for_primitive, bias, softmax_aux, rng_state, output, doutput, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, + *seq_desc_flatten, config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index dc857aa22c..7447cd1871 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -213,14 +213,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto layout_group = 
nvte_get_qkv_layout_group(qkv_layout); static void FusedAttnForwardImpl( - cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens, - void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output, - void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, - size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, - size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, - bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens, + void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux, + void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, + size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic, int64_t window_size_left, int64_t window_size_right) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -303,11 +303,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *k = buffers[1]; void *v = buffers[2]; void *bias = buffers[3]; - void *q_cu_seqlens = buffers[4]; - void *kv_cu_seqlens = buffers[5]; - void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; - void *k_seq_offsets = is_ragged ? 
buffers[7] : nullptr; - void *seed = buffers[8]; + void *seed = buffers[4]; + void *q_cu_seqlens = buffers[5]; + void *kv_cu_seqlens = buffers[6]; + void *q_seq_offsets = is_ragged ? buffers[7] : nullptr; + void *k_seq_offsets = is_ragged ? buffers[8] : nullptr; /* Output buffer from XLA */ void *output = buffers[9]; @@ -316,7 +316,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *workspace = buffers[12]; FusedAttnForwardImpl( - stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed, + stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, @@ -354,24 +354,24 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, - Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, - Buffer_Type seed_buf, Result_Type output_buf, + Variadic_Buffer_Type _unused_args, Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf, Result_Type workspace_buf, Dictionary attrs) { FUSED_ATTN_FFI_GET_ATTRS; FusedAttnForwardImpl( stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), - bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), - is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, - is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), - output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), - workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, - scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(), + softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(), + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, + head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, + mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, + window_size_right); return ffi_with_cuda_error_check(); } @@ -383,11 +383,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, .Arg() // k .Arg() // v .Arg() // bias + .Arg() // seed_buf .Arg() // q_cu_seqlens .Arg() // kv_cu_seqlens .Arg() // q_seq_offsets .Arg() // k_seq_offsets - .Arg() // seed_buf + .RemainingArgs() // _cp_aux_args unused .Ret() // output .Ret() // softmax_aux .Ret() // rng_state @@ -642,9 +643,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, - Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - Dictionary attrs) { + 
Variadic_Buffer_Type _unused_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf, + Result_Type workspace_buf, Dictionary attrs) { FUSED_ATTN_FFI_GET_ATTRS; FusedAttnBackwardImpl( @@ -677,6 +678,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Arg() // kv_cu_seqlens .Arg() // q_seq_offsets .Arg() // k_seq_offsets + .RemainingArgs() // _cp_aux_args unused .Ret() // dq .Ret() // dk .Ret() // dv From 2fce82b725092339300f3a9e955912938280013f Mon Sep 17 00:00:00 2001 From: hx Date: Tue, 28 Jan 2025 00:01:30 +0800 Subject: [PATCH 061/239] [MoE][PyTorch] Add mask-based MoE permutation (#1373) * add mask-based moe permutation * change moe_chunk_permute to moe_sort_chunks_by_indices * fix __all__ in pytorch/permutation.py * fix func/var names and typos; update tols in UT --------- Signed-off-by: Hongxiao Bai Co-authored-by: Phuong Nguyen Co-authored-by: Tim Moon --- docs/api/pytorch.rst | 2 + tests/pytorch/test_permutation.py | 673 +++++++++++++++++- transformer_engine/pytorch/__init__.py | 6 +- transformer_engine/pytorch/permutation.py | 390 +++++++++- transformer_engine/pytorch/triton/__init__.py | 5 + .../pytorch/triton/permutation.py | 599 ++++++++++++++++ 6 files changed, 1625 insertions(+), 50 deletions(-) create mode 100644 transformer_engine/pytorch/triton/__init__.py create mode 100644 transformer_engine/pytorch/triton/permutation.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 986d79808c..43001feeb3 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -52,6 +52,8 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_unpermute +.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index + .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. 
autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 2fd8e49114..c29c01b433 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -2,11 +2,17 @@ # # See LICENSE for license information. +import random + import torch import pytest from typing import Dict, List -from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute +from transformer_engine.pytorch import ( + moe_permute as te_permute, + moe_unpermute as te_unpermute, + moe_sort_chunks_by_index as te_sort_chunks_by_index, +) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor @@ -18,7 +24,7 @@ torch.cuda.manual_seed(seed) -def pytorch_permute(tokens, indices, num_out_tokens: int = None): +def pytorch_permute_index_map(tokens, indices, num_out_tokens: int = None): """ Permute the tokens based on the indices. Token with the same index will be grouped together. The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. @@ -50,7 +56,7 @@ def pytorch_permute(tokens, indices, num_out_tokens: int = None): return permuted_tokens, sorted_indices -def pytorch_unpermute( +def pytorch_unpermute_index_map( permuted_tokens: torch.Tensor, sorted_indices: torch.Tensor, probs: torch.Tensor = None, @@ -95,6 +101,86 @@ def pytorch_unpermute( return unpermuted_tokens +def pytorch_permute_mask_map(tokens, routing_map): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + Args: + tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. 
+ routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. + """ + num_tokens, _ = tokens.shape + num_experts = routing_map.shape[1] + + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.bool().T.contiguous() + + # Create a dense expert-to-token mapping from the sparse token-to-expert mapping + token_indices = ( + torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1) + ) + sorted_indices = token_indices.masked_select(routing_map) + + # use the mapping to permute the tokens + permuted_input = tokens.index_select(0, sorted_indices) + + return permuted_input, sorted_indices + + +def pytorch_unpermute_mask_map( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + restore_shape: torch.Size, + probs: torch.Tensor = None, + routing_map: torch.Tensor = None, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + Args: + permuted_tokens (torch.Tensor): The permuted token tensor. + sorted_indices (torch.Tensor): The indices used to sort the tokens. + restore_shape (torch.Size): The shape of the unpermuted tensor. + probs (torch.Tensor, optional): The unpermuted probs tensor, + routing_map (torch.Tensor, optional): Token to expert mapping, shape + [num_tokens, num_experts]. + + Returns: + torch.Tensor: The tokens restored to their original order. + """ + _, hidden = restore_shape + + if probs is not None: + assert routing_map is not None, "Mask must be provided to permute the probs." 
+ permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous()) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + # Create an output tensor filled with zeros + output_tokens = torch.zeros( + restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype + ) + # Scatter add the permuted_input back to the original positions + output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens) + return output_tokens + + +def pytorch_sort_chunks_by_index( + input: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, +): + """ + Split and sort the input tensor based on the split_sizes and sorted indices. + return a tuple of (output, row_id_map). row_id_map is only used when fused=True. + """ + input = torch.split(input, split_sizes.tolist(), dim=0) + output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) + return output + + def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: """Estimated tolerances for a datatype @@ -112,7 +198,7 @@ def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: raise ValueError(f"Unsuppored dtype ({te_dtype})") -def _test_permutation( +def _test_permutation_index_map( te_dtype, num_tokens, num_expert, @@ -132,7 +218,8 @@ def _test_permutation( num_out_tokens = num_tokens * topK print( - f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + "index map:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) fp8 = False @@ -198,7 +285,7 @@ def _test_permutation( # PyTorch Permutation # ################################################################################################################################### - pytorch_permute_output, sorted_indices = pytorch_permute( + pytorch_permute_output, sorted_indices = pytorch_permute_index_map( pytorch_permute_fwd_input, indices, num_out_tokens ) pytorch_permute_output.backward(pytorch_permute_bwd_input, 
retain_graph=True) @@ -206,7 +293,7 @@ def _test_permutation( pytorch_unpermute_fwd_input = pytorch_permute_output.detach() pytorch_unpermute_fwd_input.requires_grad_(True) - pytorch_unpermute_output = pytorch_unpermute( + pytorch_unpermute_output = pytorch_unpermute_index_map( pytorch_unpermute_fwd_input, sorted_indices, probs=probs ) pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) @@ -220,7 +307,9 @@ def _test_permutation( te_permute_fwd_input.requires_grad_(True) te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() - te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens) + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, indices, num_out_tokens, map_type="index" + ) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -231,7 +320,9 @@ def _test_permutation( te_unpermute_fwd_input.requires_grad_(True) te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() - te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) + te_unpermute_output = te_unpermute( + te_unpermute_fwd_input, row_id_map, te_probs, map_type="index" + ) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -300,10 +391,10 @@ def backward_wrapper( if BENCHMARK: t1 = perf_test_cuda_kernel( - lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) + lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens) ) t2 = perf_test_cuda_kernel( - lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens) + lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens, map_type="index") ) print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -328,10 +419,12 @@ def 
backward_wrapper( print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") t1 = perf_test_cuda_kernel( - lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) + lambda: pytorch_unpermute_index_map( + pytorch_unpermute_fwd_input, sorted_indices, probs=probs + ) ) t2 = perf_test_cuda_kernel( - lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) + lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs, map_type="index") ) print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -362,6 +455,416 @@ def backward_wrapper( print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + with_probs, + BENCHMARK=False, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "mask map:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + permute_bwd_input = torch.rand( + size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + permute_fwd_input = Float8Tensor.to_float8( + permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + permute_bwd_input = 
Float8Tensor.to_float8( + permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + unpermute_bwd_input = Float8Tensor.to_float8( + unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + + pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + restore_shape = pytorch_permute_fwd_input.shape + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = None + if with_probs: + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs.requires_grad_(True) + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute_mask_map( + pytorch_permute_fwd_input, routing_map + ) + pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) + + pytorch_unpermute_fwd_input = pytorch_permute_output.detach() + pytorch_unpermute_fwd_input.requires_grad_(True) + + pytorch_unpermute_output = pytorch_unpermute_mask_map( + pytorch_unpermute_fwd_input, sorted_indices, 
restore_shape, probs, routing_map + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask" + ) + te_permute_output.backward(te_permute_bwd_input, retain_graph=True) + + te_probs = None + if with_probs: + te_probs = probs.detach() + te_probs.requires_grad_(True) + te_unpermute_fwd_input = te_permute_output.detach() + te_unpermute_fwd_input.requires_grad_(True) + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + + te_unpermute_output = te_unpermute( + te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" + ) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_permute_output_ = te_permute_output.from_float8(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) + te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + else: + 
te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + + torch.testing.assert_close( + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + ) + + if not pytorch_permute_fwd_input.numel(): + print("Empty pytorch_permute_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False + ): + # Set forward_input.grad to None to avoid grad accumulation. 
+ if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map) + ) + t2 = perf_test_cuda_kernel( + lambda: te_permute(te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask") + ) + print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_permute_output, + pytorch_permute_bwd_input, + forward_input=[pytorch_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: pytorch_unpermute_mask_map( + pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map + ) + ) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute( + te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" + ) + ) + print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_unpermute_output, + pytorch_unpermute_bwd_input, + forward_input=( + [pytorch_unpermute_fwd_input, probs] + if with_probs + else [pytorch_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_unpermute_output, + te_unpermute_bwd_input, + forward_input=( + [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + +def _test_moe_chunk_sort( + te_dtype, + num_tokens, + num_expert, + tp_size, + 
hidden_size, + BENCHMARK=False, +): + print( + "chunk permute:" + f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + + fwd_input = Float8Tensor.to_float8( + fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + bwd_input = Float8Tensor.to_float8( + bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + + pytorch_fwd_input = fwd_input.from_float8(torch.float16) + pytorch_bwd_input = bwd_input.from_float8(torch.float16) + else: + pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_fwd_input.requires_grad_(True) + + _split_sizes = [0] * (num_expert * tp_size) + for _ in range(num_tokens): + idx = random.randint(0, num_expert * tp_size - 1) + _split_sizes[idx] += 1 + split_sizes = torch.tensor(_split_sizes, dtype=torch.int32).ravel() + split_sizes_cuda = split_sizes.to(device="cuda") + + _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32) + sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel() + sorted_idxs_cuda = sorted_idxs.to(device="cuda") + + ################################################################################################################################### + # + # PyTorch Permutation + # + 
################################################################################################################################### + pytorch_output = pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs) + pytorch_output.backward(pytorch_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach() + te_fwd_input.requires_grad_(True) + te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach() + + te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) + te_output.backward(te_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_output_ = te_output.from_float8(torch.float32) + te_fwd_input_grad = te_fwd_input.grad.from_float8(torch.float32) + else: + te_output_ = te_output.float() + te_fwd_input_grad = te_fwd_input.grad.float() + + torch.testing.assert_close( + pytorch_output.float(), + te_output_, + msg=f"Mismatch in te_permute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_fwd_input.grad.float(), + te_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + + if not pytorch_fwd_input.numel(): + print("Empty pytorch_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + 
################################################################################################################################### + def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False + ): + # Set forward_input.grad to None to avoid grad accumulation. + if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs) + ) + t2 = perf_test_cuda_kernel( + lambda: te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) + ) + print(f"chunk sort\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_output, + pytorch_bwd_input, + forward_input=[pytorch_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_output, + te_bwd_input, + forward_input=[te_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + def perf_test_cuda_kernel(cuda_kernel_fn): if torch.cuda.is_available(): # create CUDA event @@ -396,7 +899,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) -def test_permutation( +def test_permutation_index_map( te_dtype, num_tokens, num_expert, @@ -407,7 +910,36 @@ def test_permutation( with_probs = True BENCHMARK = False - _test_permutation( + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) 
+@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation_mask_map( te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, @@ -430,7 +962,37 @@ def test_permutation( @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) -def test_permutation_fp8( +def test_permutation_index_map_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_mask_map_fp8( te_dtype, num_tokens, num_expert, @@ -441,7 +1003,7 @@ def test_permutation_fp8( with_probs = True BENCHMARK = False - _test_permutation( + _test_permutation_mask_map( te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, @@ -457,7 +1019,7 @@ def test_permutation_fp8( @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -def 
test_permutation_topk1_no_probs( +def test_permutation_index_map_topk1_no_probs( te_dtype, num_tokens, num_expert, @@ -468,7 +1030,7 @@ def test_permutation_topk1_no_probs( with_probs = False BENCHMARK = False - _test_permutation( + _test_permutation_index_map( te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, @@ -480,6 +1042,57 @@ def test_permutation_topk1_no_probs( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_mask_map_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, +): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_chunk_permutation( + te_dtype, + num_tokens, + num_expert, + tp_size, + hidden_size, +): + BENCHMARK = False + + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + tp_size=tp_size, + hidden_size=hidden_size, + BENCHMARK=BENCHMARK, + ) + + def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) @@ -497,7 +1110,18 @@ def test_permutation_single_case(): with_probs = True Benchmark = True - _test_permutation( + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + _test_permutation_mask_map( 
te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, @@ -508,6 +1132,15 @@ def test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + tp_size=4, + hidden_size=hidden_size, + BENCHMARK=Benchmark, + ) + if __name__ == "__main__": test_permutation_single_case() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 9b51d1369a..91d3772fd7 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -74,7 +74,11 @@ def _load_library(): from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.transformer import TransformerLayer -from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute +from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_unpermute, + moe_sort_chunks_by_index, +) from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 90cb5cc021..264b620be8 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -2,24 +2,26 @@ # # See LICENSE for license information. 
-"""Linear API""" +"""MoE Permutation API""" import warnings from typing import Tuple import torch import transformer_engine_torch as tex -from .constants import TE_DType -from .float8_tensor import Float8Tensor +import transformer_engine.pytorch.triton.permutation as triton_permutation +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.float8_tensor import Float8Tensor __all__ = [ "moe_permute", "moe_unpermute", + "moe_sort_chunks_by_index", ] -class _moe_permute(torch.autograd.Function): - """functional Permute""" +class _moe_permute_index_map(torch.autograd.Function): + """functional Permute with index router map""" workspace = None max_expanded_token_num = 0 @@ -28,7 +30,7 @@ class _moe_permute(torch.autograd.Function): def forward( ctx, inp: torch.Tensor, - indices: torch.Tensor, + index: torch.Tensor, num_out_tokens: int, max_token_num: int, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -39,9 +41,9 @@ def forward( # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." - assert indices.is_cuda, "TransformerEngine needs CUDA." + assert index.is_cuda, "TransformerEngine needs CUDA." # Shape check - assert inp.size(0) == indices.size(0), "Permute not possible" + assert inp.size(0) == index.size(0), "Permute not possible" # Data type check fp8 = isinstance(inp, Float8Tensor) @@ -51,27 +53,27 @@ def forward( inp = inp._data else: dtype = TE_DType[inp.dtype] - if indices.dtype != torch.int32: + if index.dtype != torch.int32: warnings.warn( - f"The data type of the input `indices` of Permute is {indices.dtype}! " + f"The data type of the input `index` of Permute is {index.dtype}! " "The recommended type is torch.int32." 
) - indices = indices.to(torch.int32) + index = index.to(torch.int32) - topK = indices.size(1) + topK = index.size(1) input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK - if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: - _moe_permute.max_expanded_token_num = input_max_expanded_token_num - _moe_permute.workspace = [] + if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num: + _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num + _moe_permute_index_map.workspace = [] - permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( + permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd( inp, dtype, - indices, + index, num_out_tokens, - _moe_permute.workspace, - _moe_permute.max_expanded_token_num, + _moe_permute_index_map.workspace, + _moe_permute_index_map.max_expanded_token_num, ) if fp8: @@ -80,8 +82,8 @@ def forward( ) ctx.row_id_map = row_id_map - ctx.num_tokens = indices.size(0) - ctx.topK = indices.size(1) + ctx.num_tokens = index.size(0) + ctx.topK = index.size(1) ctx.fp8 = fp8 return permuted_act, row_id_map @@ -122,8 +124,8 @@ def backward( return act_grad, None, None, None -class _moe_unpermute(torch.autograd.Function): - """functional Unpermute""" +class _moe_unpermute_index_map(torch.autograd.Function): + """functional Unpermute with index router map""" @staticmethod def forward( @@ -225,21 +227,238 @@ def backward( return act_grad, None, prob_grad +class _moe_permute_mask_map(torch.autograd.Function): + """functional Permute with mask router map""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + if not inp.numel(): + return inp, torch.tensor([], device=inp.device) + + assert inp.is_cuda, "TransformerEngine needs CUDA." 
+ assert routing_map.is_cuda, "TransformerEngine needs CUDA." + + assert inp.size(0) == routing_map.size(0), "Permute not possible" + num_tokens, hidden_size = inp.size() + num_experts = routing_map.size(1) + assert ( + num_out_tokens is not None + ), "num_out_tokens must be provided to the fused permute function." + + row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) + + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + output = triton_permutation.permute_with_mask_map( + inp, + row_id_map, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) + if fp8: + output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv) + + ctx.save_for_backward(row_id_map) + ctx.num_experts = num_experts + ctx.num_tokens = num_tokens + ctx.hidden_size = hidden_size + return output, row_id_map + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + _, + ) -> Tuple[torch.Tensor, ...]: + # pylint: disable=missing-function-docstring + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None + + act_grad = None + if ctx.needs_input_grad[0]: + (row_id_map,) = ctx.saved_tensors + fp8 = isinstance(permuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + permuted_act_grad = permuted_act_grad._data + else: + fp8_dtype = None + act_grad = triton_permutation.unpermute_with_mask_map( + permuted_act_grad, + row_id_map, + None, + ctx.num_tokens, + ctx.num_experts, + ctx.hidden_size, + fp8_dtype, + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv * ctx.num_experts, + ) + return act_grad, None, None + + +class _moe_unpermute_mask_map(torch.autograd.Function): + """functional Unpermute with mask router map""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + row_id_map: 
torch.Tensor, + probs: torch.Tensor, + restore_shape: torch.Size, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + if not inp.numel(): + ctx.probs = probs + return inp + + if restore_shape is None: + restore_shape = inp.shape + num_tokens, hidden_size = restore_shape + num_experts = row_id_map.size(0) + + with_probs = probs is not None + if with_probs: + assert probs.is_cuda, "TransformerEngine needs CUDA." + if probs.dtype != torch.float32: + warnings.warn( + f"The data type of the input `probs` of Unpermute is {probs.dtype}! " + "The recommended type is torch.float32." + ) + probs = probs.to(torch.float32) + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + if not with_probs: + fp8_scale_inv = inp._scale_inv * num_experts + else: + fp8_scale_inv = inp._scale_inv + inp = inp._data + else: + fp8_dtype = None + unpermuted_output = triton_permutation.unpermute_with_mask_map( + inp, + row_id_map, + probs, + num_tokens, + num_experts, + hidden_size, + fp8_dtype=fp8_dtype, + ) + if fp8: + unpermuted_output = Float8Tensor( + data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + + if with_probs: + ctx.save_for_backward(inp, row_id_map, probs) + else: + ctx.save_for_backward(row_id_map) + ctx.num_experts = num_experts + ctx.num_tokens = num_tokens + ctx.num_permuted_tokens = inp.size(0) + ctx.hidden_size = hidden_size + ctx.with_probs = with_probs + return unpermuted_output + + @staticmethod + def backward(ctx, unpermuted_act_grad): + # pylint: disable=missing-function-docstring + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.probs, None + + act_grad = None + probs_grad = None + if ctx.needs_input_grad[0]: + if ctx.with_probs: + fwd_input, row_id_map, probs = ctx.saved_tensors + else: + (row_id_map,) = ctx.saved_tensors + + fp8 = 
isinstance(unpermuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = unpermuted_act_grad._fp8_dtype + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + else: + fp8_dtype = None + + if ctx.with_probs: + act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs( + unpermuted_act_grad, + row_id_map, + fwd_input, + probs, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + fp8_dtype, + ) + else: + act_grad = triton_permutation.permute_with_mask_map( + unpermuted_act_grad, + row_id_map, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + ) + + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + + if not ctx.needs_input_grad[2]: + probs_grad = None + return act_grad, None, probs_grad, None + + def moe_permute( inp: torch.Tensor, - indices: torch.Tensor, + routing_map: torch.Tensor, num_out_tokens: int = -1, max_token_num: int = -1, + map_type: str = "mask", ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Permute the tokens based on the indices. Token with the same index will be grouped together. + Permute the tokens based on the routing_map. Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. Parameters ---------- inp: torch.Tensor Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. - indices: torch.Tensor - The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. + routing_map: torch.Tensor + The token to expert mapping tensor. + If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'. 
+ The values in it are the routed expert indices. num_out_tokens: int, default = -1 The effective output token count, representing the number of tokens not dropped. By default, set to '-1', meaning no tokens are dropped. @@ -247,14 +466,23 @@ def moe_permute( The maximum number of tokens, used for workspace allocation. By default, set to '-1', meaning the calculation of the size of workspace is automatically taken over by the operator. + map_type: str, default = 'mask' + Type of the routing map tensor. + Options are: 'mask', 'index'. """ - return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num) + if map_type == "index": + return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) + if map_type == "mask": + return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens) + raise ValueError("map_type should be one of 'mask' or 'index'") def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor = None, + restore_shape: torch.Tensor = None, + map_type: str = "mask", ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -271,5 +499,109 @@ def moe_unpermute( The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. + restore_shape: torch.Tensor + The output shape after the unpermute operation. + map_type: str, default = 'mask' + Type of the routing map tensor. Should be the same as the value passed to moe_permute. + Options are: 'mask', 'index'. 
+ """ + if map_type == "index": + return _moe_unpermute_index_map.apply(inp, row_id_map, probs) + if map_type == "mask": + return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape) + raise ValueError("map_type should be one of 'mask' or 'index'") + + +class _moe_chunk_sort(torch.autograd.Function): + """functional MoE chunk permute""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + if not inp.numel(): + return inp, torch.tensor([], device=inp.device) + + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert split_sizes.is_cuda, "TransformerEngine needs CUDA." + assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + + num_tokens, hidden_size = inp.shape + num_splits = split_sizes.size(0) + assert num_splits == sorted_idxs.size(0) + + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + output, row_id_map = triton_permutation.sort_chunks_by_idx( + inp, + split_sizes, + sorted_idxs, + num_tokens, + hidden_size, + num_splits, + ) + if fp8: + output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv) + + ctx.save_for_backward(row_id_map) + ctx.num_tokens = num_tokens + ctx.hidden_size = hidden_size + return output + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, ...]: + # pylint: disable=missing-function-docstring + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None + + act_grad = None + if ctx.needs_input_grad[0]: + (row_id_map,) = ctx.saved_tensors + fp8 = isinstance(permuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + permuted_act_grad = permuted_act_grad._data + act_grad = triton_permutation.sort_chunks_by_map( + 
permuted_act_grad, + row_id_map, + ctx.num_tokens, + ctx.hidden_size, + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + return act_grad, None, None + + +def moe_sort_chunks_by_index( + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_index: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split and sort the input tensor based on the split_sizes and sorted indices. + The inp tensor is split along dim-0 according to the split_sizes list and then sorted + according to sorted_index. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + split_sizes: torch.Tensor + Chunk sizes of the inp tensor along the 0-th dimension. + sorted_index: torch.Tensor + Chunk indices used to permute the chunks. """ - return _moe_unpermute.apply(inp, row_id_map, probs) + return _moe_chunk_sort.apply(inp, split_sizes, sorted_index) diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py new file mode 100644 index 0000000000..76c9b98d0e --- /dev/null +++ b/transformer_engine/pytorch/triton/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Kernels written with OpenAI Triton.""" diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py new file mode 100644 index 0000000000..767362e8c1 --- /dev/null +++ b/transformer_engine/pytorch/triton/permutation.py @@ -0,0 +1,599 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Permutation kernels written with OpenAI Triton.""" + +from typing import Union + +import torch +import triton +import triton.language as tl + +from transformer_engine_torch import DType as TE_DType + + +@triton.jit +def _row_id_map_pass_1_kernel( + # pointers + routing_map_ptr, + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # strides + stride_routing_map_token, + stride_routing_map_expert, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + expert_token_mask = tl.load( + routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, + mask=(offset < num_tokens), + other=0, + ).to(tl.int64) + row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask + tl.store( + row_id_map_ptr + pid_m * num_tokens + offset, + row_id_within_token_block, + mask=offset < num_tokens, + ) + n_tokens_per_block = tl.sum(expert_token_mask) + tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block) + + +@triton.jit +def _row_id_map_pass_2_kernel( + # pointers + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # metas + WORKSPACE_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + row_id_within_token_block = tl.load( + row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0 + ) + + workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) + n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx) + row_id = tl.where( + row_id_within_token_block == 0, + -1, + row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, + ) + tl.store( + row_id_map_ptr + pid_m * num_tokens + offset, + row_id, + mask=(offset < num_tokens), + ) + + +def make_row_id_map( + 
routing_map: torch.Tensor, + num_tokens: int, + num_experts: int, +): + # pylint: disable=missing-function-docstring + row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda") + block_size = 256 + grid = (num_experts, triton.cdiv(num_tokens, block_size)) + workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda") + # block cumsum + _row_id_map_pass_1_kernel[grid]( + routing_map, + row_id_map, + workspace_tensor, + num_tokens, + routing_map.stride(0), + routing_map.stride(1), + block_size, + ) + # cumsum all and process the mask + _row_id_map_pass_2_kernel[grid]( + row_id_map, + workspace_tensor, + num_tokens, + triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), + block_size, + ) + return row_id_map + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], +) +@triton.jit +def _permute_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + # sizes + num_tokens, + num_experts, + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + cur_pos = 0 + while cur_pos < hidden_size: + cur_off = cur_pos + tl.arange(0, BLOCK_SIZE) + mask = cur_off < hidden_size + input_off = pid * stride_input_token + cur_off * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + for expert_idx in range(num_experts): + dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) + if dst_row != -1: + output_off = dst_row * stride_output_token + cur_off * stride_output_hidden + tl.store(output_ptr + output_off, inp, mask=mask) + cur_pos += BLOCK_SIZE + + +def permute_with_mask_map( + inp: torch.Tensor, + row_id_map: torch.Tensor, + num_tokens: int, + num_experts: int, + 
num_out_tokens: int, + hidden_size: int, +): + # pylint: disable=missing-function-docstring + output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + grid = (num_tokens,) + _permute_kernel[grid]( + inp, + output, + row_id_map, + num_tokens, + num_experts, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], +) +@triton.jit +def _unpermute_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + probs_ptr, + # sizes + num_tokens, + num_experts, + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_probs_expert, + # metas + WITH_PROBS: tl.constexpr, + FP8_DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + if FP8_DTYPE == "e5m2": + compute_type = tl.float16 + data_type = tl.float8e5 + pytorch_tensor_dtype = tl.uint8 + elif FP8_DTYPE == "e4m3": + compute_type = tl.float16 + data_type = tl.float8e4nv + pytorch_tensor_dtype = tl.uint8 + else: + compute_type = input_ptr.dtype.element_ty + assert FP8_DTYPE is None + + pid = tl.program_id(0) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + for expert_idx in range(num_experts): + src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) + if src_row != -1: + input_off = src_row * stride_input_token + current_offset * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + if FP8_DTYPE is not None: + inp = inp.to(data_type, bitcast=True).to(compute_type) + if WITH_PROBS: + prob_off = pid * 
stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + prob_off).to(compute_type) + inp *= prob + accumulator += inp + if FP8_DTYPE is not None: + if not WITH_PROBS: + # Directly adding these value may cause overflow for fp8, we scale it here. + # The outside fp8_scale_inv is also scaled in the meantime. + accumulator /= num_experts + accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + output_off = pid * stride_output_token + current_offset * stride_output_hidden + tl.store(output_ptr + output_off, accumulator, mask=mask) + current_start += BLOCK_SIZE + + +def unpermute_with_mask_map( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: Union[torch.Tensor, None], + num_tokens: int, + num_experts: int, + hidden_size: int, + fp8_dtype: TE_DType, +): + # pylint: disable=missing-function-docstring + if fp8_dtype == TE_DType.kFloat8E5M2: + fp8_dtype = "e5m2" + elif fp8_dtype == TE_DType.kFloat8E4M3: + fp8_dtype = "e4m3" + else: + fp8_dtype = None + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + grid = (num_tokens,) + _unpermute_kernel[grid]( + inp, + output, + row_id_map, + probs, + num_tokens, + num_experts, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + probs.stride(0) if probs is not None else None, + probs.stride(1) if probs is not None else None, + WITH_PROBS=probs is not None, + FP8_DTYPE=fp8_dtype, + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], +) +@triton.jit +def _unpermute_bwd_with_probs_kernel( + # pointers + fwd_output_grad_ptr, + fwd_input_grad_ptr, + fwd_input_ptr, + probs_ptr, + probs_grad_ptr, + row_id_map_ptr, + # sizes + num_tokens, + num_experts, + hidden_size, + # strides + 
stride_fwd_output_grad_token, + stride_fwd_output_grad_hidden, + stride_fwd_input_grad_token, + stride_fwd_input_grad_hidden, + stride_fwd_input_token, + stride_fwd_input_hidden, + stride_probs_token, + stride_probs_expert, + stride_probs_grad_token, + stride_probs_grad_expert, + # metas + FP8_DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + if FP8_DTYPE == "e5m2": + compute_type = tl.float16 + data_type = tl.float8e5 + pytorch_tensor_dtype = tl.uint8 + elif FP8_DTYPE == "e4m3": + compute_type = tl.float16 + data_type = tl.float8e4nv + pytorch_tensor_dtype = tl.uint8 + else: + compute_type = fwd_output_grad_ptr.dtype.element_ty + assert FP8_DTYPE is None + + pid = tl.program_id(0) + for expert_idx in range(num_experts): + dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) + if dst_row != -1: + prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_off = ( + pid * stride_fwd_output_grad_token + + current_offset * stride_fwd_output_grad_hidden + ) + inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) + if FP8_DTYPE is not None: + inp = inp.to(data_type, bitcast=True).to(compute_type) + probs_off = pid * stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + probs_off).to(compute_type) + output = inp * prob + if FP8_DTYPE is not None: + output = output.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + output_off = ( + dst_row * stride_fwd_input_grad_token + + current_offset * stride_fwd_input_grad_hidden + ) + tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) + + fwd_input_off = ( + dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden + ) + fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) + if FP8_DTYPE is not None: + fwd_input = fwd_input.to(data_type, bitcast=True) + prob_grad_accum += 
fwd_input.to(tl.float32) * inp.to(tl.float32) + current_start += BLOCK_SIZE + probs_grad = tl.sum(prob_grad_accum) + probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert + tl.store(probs_grad_ptr + probs_grad_off, probs_grad) + else: + probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert + tl.store(probs_grad_ptr + probs_grad_off, 0.0) + + +def unpermute_with_mask_map_bwd_with_probs( + fwd_output_grad: torch.Tensor, + row_id_map: torch.Tensor, + fwd_input: torch.Tensor, + probs: torch.Tensor, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, + fp8_dtype: TE_DType, +): + # pylint: disable=missing-function-docstring + if fp8_dtype == TE_DType.kFloat8E5M2: + fp8_dtype = "e5m2" + elif fp8_dtype == TE_DType.kFloat8E4M3: + fp8_dtype = "e4m3" + else: + fp8_dtype = None + act_grad = torch.empty( + (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" + ) + probs_grad = torch.empty((num_tokens, num_experts), dtype=probs.dtype, device="cuda") + grid = (num_tokens,) + _unpermute_bwd_with_probs_kernel[grid]( + fwd_output_grad, + act_grad, + fwd_input, + probs, + probs_grad, + row_id_map, + num_tokens, + num_experts, + hidden_size, + fwd_output_grad.stride(0), + fwd_output_grad.stride(1), + act_grad.stride(0), + act_grad.stride(1), + fwd_input.stride(0), + fwd_input.stride(1), + probs.stride(0), + probs.stride(1), + probs_grad.stride(0), + probs_grad.stride(1), + fp8_dtype, + ) + return act_grad, probs_grad + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], +) +@triton.jit +def _sort_chunks_by_idxs_kernel( + # pointers + input_ptr, + split_sizes_ptr, + sorted_indices_ptr, + output_ptr, + dst_rows_ptr, + # sizes + num_splits, + hidden_size, + # strides + 
stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + # metas + IDX_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + load_split_offset = tl.arange(0, IDX_LOAD_WIDTH) + sorted_indices = tl.load( + sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits + ) + + # get chunk idx of the current token in the input tensor + input_chunk_idx = -1 + in_chunk_offset = tl.zeros([], dtype=tl.int64) + acc_chunk_sizes = tl.zeros([], dtype=tl.int64) + cursor = 0 + while cursor < num_splits: + cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64) + acc_chunk_sizes += cur_chunk_size + if input_chunk_idx == -1 and acc_chunk_sizes > pid: + input_chunk_idx = cursor + in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size) + cursor += 1 + + # get chunk idx of the current token in the output tensor + output_chunk_idx = 0 + cursor = 0 + while cursor < num_splits: + cur_input_idx = tl.load(sorted_indices_ptr + cursor) + if cur_input_idx == input_chunk_idx: + output_chunk_idx = cursor + cursor += 1 + + # make row_id_map + output_split_sizes = tl.load( + split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits + ).to(tl.int64) + output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) + dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + tl.store(dst_rows_ptr + pid, dst_row) + + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_offsets = pid * stride_input_token + current_offset * stride_input_hidden + output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden + inp = tl.load(input_ptr + input_offsets, mask=mask) + tl.store(output_ptr + output_offsets, inp, mask=mask) + current_start += BLOCK_SIZE + + +def sort_chunks_by_idx( + inp: torch.Tensor, + split_sizes: torch.Tensor, + 
sorted_indices: torch.Tensor, + num_tokens: int, + hidden_size: int, + num_splits: int, +): + # pylint: disable=missing-function-docstring + row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda") + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + grid = (num_tokens,) + _sort_chunks_by_idxs_kernel[grid]( + inp, + split_sizes, + sorted_indices, + output, + row_id_map, + num_splits, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + triton.next_power_of_2(num_splits), + ) + return output, row_id_map + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], +) +@triton.jit +def _sort_chunks_by_map( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + # sizes + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + dst_row = tl.load(row_id_map_ptr + pid) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden + output_offsets = pid * stride_output_token + current_offset * stride_output_hidden + inp = tl.load(input_ptr + input_offsets, mask=mask) + tl.store(output_ptr + output_offsets, inp, mask=mask) + current_start += BLOCK_SIZE + + +def sort_chunks_by_map( + inp: torch.Tensor, + row_id_map: torch.Tensor, + num_tokens: int, + hidden_size: int, +): + # pylint: disable=missing-function-docstring + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + grid = (num_tokens,) + _sort_chunks_by_map[grid]( + inp, + output, + row_id_map, + hidden_size, 
+ inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + ) + return output From 199e6123d56d03b376c4aa483a0a51f938b1bac4 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 27 Jan 2025 16:51:16 -0800 Subject: [PATCH 062/239] Use log1p(x) instead of log(1+x) (#1401) This function is more accurate than torch.log() for small values of input - https://pytorch.org/docs/stable/generated/torch.log1p.html Found with TorchFix https://github.com/pytorch-labs/torchfix/ Signed-off-by: Sergii Dymchenko Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f2120f3a73..ccceacff85 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1604,7 +1604,7 @@ def flash_attn_fwd_softmax_lse_correction( """Merge softmax stats of each step in Attention with context parallelism""" max_scale = torch.max(softmax_lse, softmax_lse_per_step) min_scale = torch.min(softmax_lse, softmax_lse_per_step) - new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) softmax_lse.copy_(new_scale) From 96534aa5691c90208ec4e8ffa6666e0e20758ffd Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Thu, 30 Jan 2025 08:32:07 -0800 Subject: [PATCH 063/239] Update neox to completed (#1439) Signed-off-by: Quentin Anthony --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 3f4d9bd4a3..fbcf05f3c9 100644 --- a/README.rst +++ b/README.rst @@ -264,10 +264,10 @@ Transformer Engine has been integrated with popular LLM frameworks such as: * `NVIDIA NeMo Framework `_ * `Amazon SageMaker Model Parallel Library `_ * `Levanter `_ 
+* `GPT-NeoX `_ * `Hugging Face Nanotron `_ - Coming soon! * `Colossal-AI `_ - Coming soon! * `PeriFlow `_ - Coming soon! -* `GPT-NeoX `_ - Coming soon! Contributing From e5369541eface67d5a76e99bfec861636c28985a Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Thu, 30 Jan 2025 17:45:17 -0800 Subject: [PATCH 064/239] Support `store_param_remainders` feature from Apex in TE Fused Adam (#1408) * Initial commit Signed-off-by: Selvaraj Anandaraj * Fixed compilation errors Signed-off-by: Selvaraj Anandaraj * Fixed syntax errors Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed NaN issue when initial param value is zero Signed-off-by: Selvaraj Anandaraj * Removed 64 bit indexing instantiation Signed-off-by: Selvaraj Anandaraj * Made this feature an opt-in Signed-off-by: Selvaraj Anandaraj * Removed arg from unscaled state Signed-off-by: Selvaraj Anandaraj * Fixed compilation error Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleaned up errors Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added support for checkpointing Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed checkpointing logic Signed-off-by: Selvaraj Anandaraj * Added tests Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added assert failure for capturable mode Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed pylint errors Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: 
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fused_optimizer.py | 19 ++- transformer_engine/pytorch/csrc/extensions.h | 6 + .../multi_tensor/multi_tensor_adam.cu | 152 ++++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 4 + .../pytorch/optimizers/fused_adam.py | 97 +++++++++-- 5 files changed, 264 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index be01f2c011..96acb699ad 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -184,6 +184,7 @@ def gen_precision_aware_test( grad_dtype, exp_avg_dtype, exp_avg_sq_dtype, + store_param_remainders=False, model_rtol=None, model_atol=None, master_rtol=None, @@ -220,6 +221,7 @@ def gen_precision_aware_test( "weight_decay": 0, "amsgrad": False, } + ref_optim = torch.optim.Adam(ref_params, **options) tst_optim = te.optimizers.FusedAdam( model_params, @@ -228,6 +230,7 @@ def gen_precision_aware_test( exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype, use_decoupled_grad=True, + store_param_remainders=store_param_remainders, **options, ) @@ -237,7 +240,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer): p.decoupled_grad = p_ref.grad.clone().to(grad_dtype) ref_optimizer.step() tst_optimizer.step() - if use_master_weights: + if use_master_weights and not store_param_remainders: master_weights_to_fp32 = [ tst_optim.get_unscaled_state(p, "master_param") for p in model_params ] @@ -270,6 +273,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer): exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype, use_decoupled_grad=True, + store_param_remainders=store_param_remainders, **options, ) tst_optim.load_state_dict(state_dict) @@ -300,6 +304,19 @@ def test_fp32_master(self): exp_avg_sq_dtype=torch.float32, ) + @pytest.mark.skipif(not is_bf16_compatible(), 
reason="bf16 if not supported") + def test_fp32_master_store_param_remainders(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + store_param_remainders=True, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") def test_fp16_master(self): self.gen_precision_aware_test( diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 67fd1caf5b..58527ef6d5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -479,6 +479,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, const int step, const int mode, const int bias_correction, const float weight_decay); +void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay); + void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index cb5e878fb2..548dd5a267 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -179,6 +179,122 @@ struct AdamFunctorMaster { } }; +template +struct AdamFunctorMasterParamRemainder { + __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, + TensorListMetadata<5> &tl, // NOLINT(*) + const float 
beta1, const float beta2, + const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + int16_t *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + int16_t *p_remainder = reinterpret_cast(tl.addresses[4][tensor_loc]); + p_remainder += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + union fp32_or_int162 { + float fp32; + int16_t int16[2]; + }; + fp32_or_int162 local_master_param[ILP]; + int16_t local_p[ILP]; + int16_t local_p_rem[ILP]; + MATH_T r_g[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + + local_p[ii] = static_cast(p[i]); + local_p_rem[ii] = static_cast(p_remainder[i]); + } else { + r_g[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + + local_p[ii] = int16_t(0); + local_p_rem[ii] = int16_t(0); + } + } +// Reconstruct FP32 params +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (local_p_rem[ii] < 0) local_p[ii]--; // Undo rounding + local_master_param[ii].int16[1] = local_p[ii]; + local_master_param[ii].int16[0] = local_p_rem[ii]; + } + + MATH_T *r_p = reinterpret_cast(local_master_param); + +#pragma unroll + for (int ii = 
0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } + +// Split into BF16 params (rounded-to-nearest) and remainders +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + local_p[ii] = local_master_param[ii].int16[1]; + local_p_rem[ii] = local_master_param[ii].int16[0]; + if (local_p_rem[ii] < 0) local_p[ii]++; // Round up + } + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p_remainder[i] = static_cast(local_p_rem[ii]); + p[i] = static_cast(local_p[ii]); + + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + template struct AdamFunctor { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, @@ -548,6 +664,42 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, AT_CUDA_CHECK(cudaGetLastError()); } +void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay) { + using namespace at; 
+ + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + const auto g_in_type = tensor_lists[0][0].scalar_type(); + const auto p_in_type = tensor_lists[1][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 5: g, p, m, v, p_master + TORCH_CHECK(tl_size == 5, "tensor list must contain 5"); + TORCH_CHECK(p_in_type == at::ScalarType::BFloat16, + "Adam with BF16 param remainders requires BF16 params"); + + // g, p, m, v, p_master + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMasterParamRemainder(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + + AT_CUDA_CHECK(cudaGetLastError()); +} + void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 165855d430..e5d8744eef 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -213,6 +213,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_adam", &multi_tensor_adam_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda, + "Compute and apply gradient update to parameters for Adam optimizer" + "where the master parameters only store the remainder bits", + py::call_guard()); m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda, "Compute and apply 
gradient update to parameters for Adam optimizer", py::call_guard()); diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 170c95442f..b86c973304 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -94,6 +94,13 @@ class FusedAdam(torch.optim.Optimizer): instead of ".grad" for reading gradients. It's useful when the dtypes of grad and param are different. (default: False) + store_param_remainders (bool, optional): Whether to store entire FP32 master + params or just store the trailing 16 remainder bits. Whole FP32 master can be + reconstructed from BF16 params plus the trailing remainder bits. Works only + when param type is BF16 and master weight type is FP32, no effect otherwise. + Useful memory saving optimization. + (default: False) + .. _Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -118,6 +125,7 @@ def __init__( exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32, use_decoupled_grad=False, + store_param_remainders=False, ): if amsgrad: @@ -142,6 +150,8 @@ def __init__( raise RuntimeError("Capturable mode only supports fp32 exp_avg.") if capturable and exp_avg_sq_dtype != torch.float32: raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq") + if capturable and store_param_remainders: + raise RuntimeError("Capturable mode doesn't support storing param remainders") # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr @@ -172,6 +182,7 @@ def __init__( # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") self.multi_tensor_adam = tex.multi_tensor_adam + self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable 
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master @@ -192,6 +203,10 @@ def __init__( } self._scales = {} self.use_decoupled_grad = use_decoupled_grad + # Works only when master params is in FP32 + self.store_param_remainders = ( + store_param_remainders and master_weights and master_weight_dtype == torch.float32 + ) def zero_grad(self): # pylint: disable=missing-function-docstring @@ -261,7 +276,14 @@ def get_unscaled_state(self, param, state_name): unscaled = state[state_name].float() unscaled.mul_(self._scales[param][state_name]) elif dtype == torch.float32: - assert state[state_name].dtype == torch.float32 + if ( + self.store_param_remainders + and state_name == "master_param" + and param.dtype == torch.bfloat16 + ): + assert state[state_name].dtype == torch.int16 + else: + assert state[state_name].dtype == torch.float32 unscaled = state[state_name] else: raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.") @@ -279,10 +301,19 @@ def set_scaled_state(self, param, state_name, unscaled_state): and 'master_param`. unscaled_state (torch.Tensor): The original high-precision(FP32) state. 
""" - assert unscaled_state.dtype == torch.float32 + store_param_remainders = ( + self.store_param_remainders + and state_name == "master_param" + and param.dtype == torch.bfloat16 + ) + + if store_param_remainders: + assert unscaled_state.dtype == torch.int16 + else: + assert unscaled_state.dtype == torch.float32 state = self.state[param] if state_name not in state: - self._initialize_state(param, state_name, False) + self._initialize_state(param, state_name, False, store_param_remainders) dtype = self.name_to_dtype_map[state_name] if dtype != torch.float32: @@ -291,7 +322,9 @@ def set_scaled_state(self, param, state_name, unscaled_state): else: state[state_name].copy_(unscaled_state) - def _initialize_state(self, param, state_name, zero_buffer: bool): + def _initialize_state( + self, param, state_name, zero_buffer: bool, store_param_remainders: bool = False + ): """Initialize one of the optimizer states according to `state_name`. Arguments: @@ -299,9 +332,13 @@ def _initialize_state(self, param, state_name, zero_buffer: bool): state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', and 'master_param`. zero_buffer (bool): Whether to initialize the optimizer state with zeros. + store_param_remainders (bool): Store only trailing remainder bits. """ dtype = self.name_to_dtype_map[state_name] - data = torch.empty_like(param, dtype=dtype) + if store_param_remainders: + data = torch.zeros_like(param, dtype=torch.int16) + else: + data = torch.empty_like(param, dtype=dtype) if zero_buffer: data.zero_() @@ -322,17 +359,24 @@ def _initialize_state(self, param, state_name, zero_buffer: bool): [1], dtype=torch.float32, device=param.device ) - def initialize_state(self, param): + def initialize_state(self, param, store_param_remainders): """Initialize optimizer states. Arguments: param (torch.nn.Parameter): One of parameters in this optimizer. + store_param_remainders (bool): Store trailing remainder bits. 
""" self._initialize_state(param, "exp_avg", zero_buffer=True) self._initialize_state(param, "exp_avg_sq", zero_buffer=True) if self.master_weights: - self._initialize_state(param, "master_param", zero_buffer=False) - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + self._initialize_state( + param, + "master_param", + zero_buffer=False, + store_param_remainders=store_param_remainders, + ) + if not store_param_remainders: + self.set_scaled_state(param, "master_param", param.clone().detach().float()) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all @@ -377,7 +421,15 @@ def load_state_dict(self, state_dict): param = id_map[k] self.state[param] = {} for name in v: - self.set_scaled_state(param, name, v[name].float()) + if ( + self.store_param_remainders + and name == "master_param" + and param.dtype == torch.bfloat16 + ): + self.set_scaled_state(param, name, v[name]) + assert v[name].dtype == torch.int16 + else: + self.set_scaled_state(param, name, v[name].float()) def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. 
@@ -444,9 +496,11 @@ def step(self, closure=None, grad_scaler=None): for p in group["params"]: state = self.state[p] + store_param_remainders = self.store_param_remainders and p.dtype == torch.bfloat16 + # State initialization if len(state) == 0: - self.initialize_state(p) + self.initialize_state(p, store_param_remainders) if self.use_decoupled_grad: p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None @@ -462,8 +516,12 @@ def step(self, closure=None, grad_scaler=None): unscaled_state = {} for name in ["exp_avg", "exp_avg_sq", "master_param"]: if name in state: - unscaled = self.get_unscaled_state(p, name) - unscaled_state[name] = unscaled + if name == "master_param" and store_param_remainders: + unscaled_state[name] = self.state[p][name] + assert unscaled_state[name].dtype == torch.int16 + else: + unscaled = self.get_unscaled_state(p, name) + unscaled_state[name] = unscaled if self.name_to_dtype_map[name] != torch.float32: unscaled_lists[name].append(unscaled) scaled_lists[name].append(state[name]) @@ -506,6 +564,12 @@ def step(self, closure=None, grad_scaler=None): ) if has_fp16 and has_bf16: + if self.store_param_remainders: + raise RuntimeError( + "FusedAdam doesn't support a mix of FP16/BF16 weights + Store param" + " remainder." + ) + # simple to add support for this, but not needed for now raise RuntimeError( "FusedAdam does not support a mix of float16 and bfloat16 model weights." 
@@ -599,7 +663,14 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N v_of_f16_model, p_main_of_f16_model, ] - apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + if self.store_param_remainders and has_bf16 and not has_fp16: + # When you have BF16 params and need FP32 master params, you can reconstruct + # the FP32 master params with BF16 params + int16 remainders + apply_multi_tensor_adam( + self.multi_tensor_adam_param_remainder, tensor_lists + ) + else: + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) if len(p_fp8_model) > 0: tensor_lists = [ g_of_fp8_model, From 544dd14b4301beb47136f273deff3f532cdde181 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 7 Feb 2025 14:17:18 -0800 Subject: [PATCH 065/239] Update main branch with TE 2.0 code, update version to 2.1.0.dev0 Signed-off-by: Przemek Tredak --- .github/workflows/build.yml | 20 - .github/workflows/lint.yml | 27 - .gitignore | 1 - 3rdparty/cudnn-frontend | 2 +- README.rst | 4 +- build_tools/VERSION.txt | 2 +- build_tools/build_ext.py | 67 +- build_tools/paddle.py | 92 - build_tools/pytorch.py | 1 - build_tools/utils.py | 16 +- build_tools/wheel_utils/build_wheels.sh | 36 - docs/api/common.rst | 2 +- docs/api/framework.rst | 1 - docs/api/paddle.rst | 34 - docs/api/pytorch.rst | 2 - docs/examples/attention/attention.ipynb | 54 +- docs/installation.rst | 2 +- examples/README.md | 5 +- examples/paddle/mnist/README.md | 7 - .../paddle/mnist/test_single_gpu_mnist.py | 291 -- pylintrc | 1 - qa/L0_jax_unittest/test.sh | 2 +- qa/L0_paddle_lint/test.sh | 24 - qa/L0_paddle_unittest/test.sh | 10 - qa/L0_paddle_wheel/test.sh | 37 - qa/L0_pytorch_unittest/test.sh | 3 +- qa/L1_pytorch_distributed_unittest/test.sh | 4 +- qa/L1_pytorch_onnx_test/test.sh | 16 - qa/L3_pytorch_FA_versions_test/test.sh | 13 +- setup.py | 25 +- tests/cpp/CMakeLists.txt | 6 +- tests/cpp/operator/CMakeLists.txt | 18 +- tests/cpp/operator/test_act.cu | 88 +- 
tests/cpp/operator/test_cast.cu | 130 + tests/cpp/operator/test_cast_dbias.cu | 181 ++ tests/cpp/operator/test_cast_dbias_dgelu.cu | 196 ++ tests/cpp/operator/test_cast_gated_swiglu.cu | 165 ++ tests/cpp/operator/test_cast_mxfp8.cu | 636 ++++ .../operator/test_cast_mxfp8_gated_swiglu.cu | 470 +++ tests/cpp/operator/test_cast_transpose.cu | 26 +- .../cpp/operator/test_cast_transpose_dbias.cu | 53 +- .../test_cast_transpose_dbias_dgelu.cu | 39 +- .../operator/test_cast_transpose_dgeglu.cu | 26 +- tests/cpp/operator/test_causal_softmax.cu | 18 +- tests/cpp/operator/test_dequantize_mxfp8.cu | 452 +++ .../cpp/operator/test_multi_cast_transpose.cu | 37 +- tests/cpp/operator/test_multi_padding.cu | 10 +- tests/cpp/operator/test_normalization.cu | 107 +- .../cpp/operator/test_normalization_mxfp8.cu | 337 +++ tests/cpp/operator/test_qdq.cu | 22 +- tests/cpp/operator/test_swizzle.cu | 165 ++ tests/cpp/operator/test_transpose.cu | 8 +- tests/cpp/test_common.cu | 670 ++++- tests/cpp/test_common.h | 345 ++- tests/cpp/util/CMakeLists.txt | 7 +- tests/jax/conftest.py | 3 - tests/jax/test_layer.py | 39 +- tests/jax/utils.py | 23 +- tests/paddle/dist_launcher.py | 145 - tests/paddle/parallel_tests/amax_reduction.py | 87 - tests/paddle/parallel_tests/attention_tp.py | 234 -- tests/paddle/parallel_tests/group_sharding.py | 188 -- .../parallel_tests/layernorm_linear_tp.py | 182 -- .../paddle/parallel_tests/layernorm_mlp_tp.py | 197 -- tests/paddle/parallel_tests/linear_pp.py | 235 -- tests/paddle/parallel_tests/linear_tp.py | 222 -- tests/paddle/parallel_tests/transformer_tp.py | 250 -- .../recompute_transformer_encoder.py | 71 - tests/paddle/test_install.py | 11 - tests/paddle/test_layers.py | 1663 ----------- tests/paddle/test_master_grad.py | 92 - tests/paddle/test_operators.py | 1201 -------- tests/paddle/test_parallel.py | 99 - tests/paddle/test_recompute.py | 56 - tests/paddle/utils.py | 221 -- tests/pytorch/custom_ort_ops/.gitignore | 3 - 
tests/pytorch/custom_ort_ops/CMakeLists.txt | 29 - tests/pytorch/custom_ort_ops/README.md | 22 - tests/pytorch/custom_ort_ops/build.sh | 17 - .../custom_ort_ops/custom_op_library.cc | 102 - .../distributed/run_gemm_with_overlap.py | 284 +- .../distributed/run_layer_with_overlap.py | 94 +- tests/pytorch/distributed/run_numerics.py | 81 +- .../distributed/test_comm_gemm_overlap.py | 181 +- tests/pytorch/distributed/test_fusible_ops.py | 172 +- tests/pytorch/distributed/test_numerics.py | 22 +- tests/pytorch/distributed/test_torch_fsdp2.py | 45 +- .../fused_attn/run_fused_attn_with_cp.py | 11 +- tests/pytorch/fused_attn/test_fused_attn.py | 404 +-- tests/pytorch/test_cpu_offloading.py | 57 + tests/pytorch/test_cuda_graphs.py | 22 +- tests/pytorch/test_float8tensor.py | 165 +- tests/pytorch/test_fused_optimizer.py | 3 +- tests/pytorch/test_fusible_ops.py | 577 ++-- tests/pytorch/test_numerics.py | 241 +- tests/pytorch/test_onnx_export.py | 1562 ---------- tests/pytorch/test_permutation.py | 38 +- tests/pytorch/test_recipe.py | 74 +- tests/pytorch/test_sanity.py | 100 +- tests/pytorch/test_torch_save_load.py | 474 --- transformer_engine/__init__.py | 10 - transformer_engine/common/CMakeLists.txt | 9 +- .../common/activation/activation_template.h | 130 +- transformer_engine/common/activation/gelu.cu | 29 +- transformer_engine/common/activation/relu.cu | 28 +- .../common/activation/swiglu.cu | 14 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 471 +-- .../userbuffers/userbuffers.cu | 1 + transformer_engine/common/common.cu | 121 +- transformer_engine/common/common.h | 207 +- .../common/fused_attn/fused_attn.cpp | 55 +- .../fused_attn_f16_arbitrary_seqlen.cu | 6 +- .../common/fused_attn/fused_attn_fp8.cu | 215 +- .../common/gemm/cublaslt_gemm.cu | 215 +- .../include/transformer_engine/activation.h | 165 +- .../common/include/transformer_engine/cast.h | 199 +- .../transformer_engine/cast_transpose_noop.h | 19 +- .../transformer_engine/comm_gemm_overlap.h | 155 +- 
.../include/transformer_engine/recipe.h | 19 +- .../include/transformer_engine/swizzle.h | 37 + .../transformer_engine/transformer_engine.h | 246 +- .../include/transformer_engine/transpose.h | 291 +- .../common/normalization/common.cpp | 166 +- .../common/normalization/common.h | 44 +- .../common/normalization/layernorm/ln_api.cpp | 39 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 46 +- transformer_engine/common/recipe/__init__.py | 52 +- .../common/recipe/delayed_scaling.cu | 100 +- transformer_engine/common/swizzle/swizzle.cu | 338 +++ .../common/transformer_engine.cpp | 334 ++- .../common/transpose/cast_transpose.cu | 256 +- .../common/transpose/cast_transpose.h | 28 + .../common/transpose/cast_transpose_fusion.cu | 418 +-- .../common/transpose/multi_cast_transpose.cu | 68 +- .../transpose/rtc/cast_transpose_fusion.cu | 29 +- .../common/transpose/transpose.cu | 11 +- .../common/transpose/transpose_fusion.cu | 31 +- transformer_engine/common/util/cast.cu | 180 +- .../common/util/cast_gated_kernels.cuh | 1091 +++++++ .../common/util/cast_kernels.cuh | 1251 ++++++++ .../common/util/cuda_runtime.cpp | 20 + transformer_engine/common/util/cuda_runtime.h | 10 + .../common/util/dequantize_kernels.cuh | 360 +++ transformer_engine/common/util/ptx.cuh | 300 ++ .../common/util/pybind_helper.h | 152 +- transformer_engine/common/util/system.h | 2 - .../common/util/vectorized_pointwise.h | 112 +- transformer_engine/common/utils.cuh | 111 + .../jax/csrc/extensions/activation.cpp | 112 +- .../jax/csrc/extensions/quantization.cpp | 8 +- .../jax/csrc/extensions/transpose.cpp | 53 +- transformer_engine/jax/fp8.py | 5 - transformer_engine/paddle/MANIFEST.in | 3 - transformer_engine/paddle/__init__.py | 60 - transformer_engine/paddle/constants.py | 74 - transformer_engine/paddle/cpp_extensions.py | 1199 -------- transformer_engine/paddle/csrc/common.cpp | 84 - transformer_engine/paddle/csrc/common.h | 185 -- transformer_engine/paddle/csrc/custom_ops.cu | 1776 ----------- 
transformer_engine/paddle/csrc/extensions.cpp | 63 - transformer_engine/paddle/distributed.py | 213 -- transformer_engine/paddle/fp8.py | 370 --- transformer_engine/paddle/fp8_buffer.py | 350 --- transformer_engine/paddle/layer/__init__.py | 12 - transformer_engine/paddle/layer/attention.py | 1161 -------- transformer_engine/paddle/layer/base.py | 571 ---- transformer_engine/paddle/layer/layernorm.py | 197 -- .../paddle/layer/layernorm_linear.py | 721 ----- .../paddle/layer/layernorm_mlp.py | 1010 ------- transformer_engine/paddle/layer/linear.py | 919 ------ transformer_engine/paddle/layer/rmsnorm.py | 175 -- transformer_engine/paddle/layer/softmax.py | 254 -- .../paddle/layer/transformer.py | 375 --- transformer_engine/paddle/profile.py | 19 - transformer_engine/paddle/recompute.py | 63 - transformer_engine/paddle/setup.py | 64 - transformer_engine/paddle/utils.py | 149 - transformer_engine/pytorch/__init__.py | 15 - transformer_engine/pytorch/attention.py | 2608 ++++++----------- transformer_engine/pytorch/constants.py | 4 + .../pytorch/cpp_extensions/__init__.py | 5 - .../pytorch/cpp_extensions/_common.py | 87 - .../pytorch/cpp_extensions/activation.py | 237 -- .../pytorch/cpp_extensions/cast.py | 93 - .../pytorch/cpp_extensions/fused_attn.py | 970 +----- .../pytorch/cpp_extensions/gemm.py | 544 +--- .../pytorch/cpp_extensions/normalization.py | 260 -- .../pytorch/cpp_extensions/padding.py | 29 - .../pytorch/cpp_extensions/transpose.py | 230 -- transformer_engine/pytorch/cpu_offload.py | 18 +- transformer_engine/pytorch/csrc/common.cpp | 148 +- transformer_engine/pytorch/csrc/common.h | 169 +- transformer_engine/pytorch/csrc/extensions.h | 518 +--- .../pytorch/csrc/extensions/activation.cpp | 298 +- .../pytorch/csrc/extensions/apply_rope.cpp | 8 +- .../pytorch/csrc/extensions/attention.cu | 965 +----- .../pytorch/csrc/extensions/bias.cpp | 51 + .../pytorch/csrc/extensions/cast.cpp | 147 +- .../csrc/extensions/comm_gemm_overlap.cpp | 431 +-- 
.../pytorch/csrc/extensions/gemm.cpp | 488 ++- .../pytorch/csrc/extensions/normalization.cpp | 295 +- .../pytorch/csrc/extensions/padding.cpp | 1 + .../pytorch/csrc/extensions/permutation.cu | 3 + .../pytorch/csrc/extensions/pybind.cpp | 348 +-- .../pytorch/csrc/extensions/quantizer.cpp | 227 ++ .../pytorch/csrc/extensions/recipe.cpp | 23 +- .../pytorch/csrc/extensions/softmax.cpp | 16 +- .../pytorch/csrc/extensions/swizzle.cpp | 120 + .../pytorch/csrc/extensions/transpose.cpp | 482 +-- .../csrc/extensions/type_converters.cpp | 79 + .../pytorch/csrc/extensions/util.cpp | 14 +- transformer_engine/pytorch/csrc/pybind.h | 73 + transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 414 --- transformer_engine/pytorch/csrc/util.h | 12 + transformer_engine/pytorch/distributed.py | 252 +- transformer_engine/pytorch/export.py | 40 - transformer_engine/pytorch/float8_tensor.py | 2 +- transformer_engine/pytorch/fp8.py | 238 +- transformer_engine/pytorch/graph.py | 16 +- transformer_engine/pytorch/module/_common.py | 150 +- transformer_engine/pytorch/module/base.py | 385 +-- .../pytorch/module/fp8_padding.py | 7 +- .../pytorch/module/fp8_unpadding.py | 9 +- .../pytorch/module/grouped_linear.py | 528 ++-- .../pytorch/module/layernorm_linear.py | 1040 +++---- .../pytorch/module/layernorm_mlp.py | 1581 +++++----- transformer_engine/pytorch/module/linear.py | 1017 +++---- transformer_engine/pytorch/ops/_common.py | 53 +- .../pytorch/ops/basic/activation.py | 161 +- .../pytorch/ops/basic/all_gather.py | 56 +- .../pytorch/ops/basic/basic_linear.py | 1024 +++---- .../pytorch/ops/basic/layer_norm.py | 76 +- .../pytorch/ops/basic/quantize.py | 30 +- .../pytorch/ops/basic/reduce_scatter.py | 52 +- .../pytorch/ops/basic/reshape.py | 5 +- .../pytorch/ops/basic/rmsnorm.py | 72 +- .../pytorch/ops/fused/backward_linear_add.py | 12 +- .../fused/forward_linear_bias_activation.py | 47 +- .../ops/fused/forward_linear_bias_add.py | 43 +- .../ops/fused/userbuffers_backward_linear.py | 13 +- 
.../ops/fused/userbuffers_forward_linear.py | 9 +- transformer_engine/pytorch/ops/op.py | 266 +- .../pytorch/optimizers/fused_adam.py | 38 +- transformer_engine/pytorch/permutation.py | 33 +- transformer_engine/pytorch/setup.py | 5 +- transformer_engine/pytorch/softmax.py | 155 +- .../pytorch/te_onnx_extensions.py | 519 ---- transformer_engine/pytorch/tensor/__init__.py | 18 +- .../pytorch/tensor/_internal/__init__.py | 5 +- .../tensor/_internal/float8_tensor_base.py | 139 + .../tensor/_internal/mxfp8_tensor_base.py | 136 + .../pytorch/tensor/float8_tensor.py | 1157 +++----- .../pytorch/tensor/mxfp8_tensor.py | 582 ++++ .../pytorch/tensor/quantized_tensor.py | 322 +- transformer_engine/pytorch/transformer.py | 4 +- transformer_engine/pytorch/utils.py | 47 +- 256 files changed, 20152 insertions(+), 34070 deletions(-) delete mode 100644 build_tools/paddle.py delete mode 100644 docs/api/paddle.rst delete mode 100644 examples/paddle/mnist/README.md delete mode 100644 examples/paddle/mnist/test_single_gpu_mnist.py delete mode 100644 qa/L0_paddle_lint/test.sh delete mode 100644 qa/L0_paddle_unittest/test.sh delete mode 100644 qa/L0_paddle_wheel/test.sh delete mode 100644 qa/L1_pytorch_onnx_test/test.sh create mode 100644 tests/cpp/operator/test_cast.cu create mode 100644 tests/cpp/operator/test_cast_dbias.cu create mode 100644 tests/cpp/operator/test_cast_dbias_dgelu.cu create mode 100644 tests/cpp/operator/test_cast_gated_swiglu.cu create mode 100644 tests/cpp/operator/test_cast_mxfp8.cu create mode 100644 tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu create mode 100644 tests/cpp/operator/test_dequantize_mxfp8.cu create mode 100644 tests/cpp/operator/test_normalization_mxfp8.cu create mode 100644 tests/cpp/operator/test_swizzle.cu delete mode 100644 tests/paddle/dist_launcher.py delete mode 100644 tests/paddle/parallel_tests/amax_reduction.py delete mode 100644 tests/paddle/parallel_tests/attention_tp.py delete mode 100644 
tests/paddle/parallel_tests/group_sharding.py delete mode 100644 tests/paddle/parallel_tests/layernorm_linear_tp.py delete mode 100644 tests/paddle/parallel_tests/layernorm_mlp_tp.py delete mode 100644 tests/paddle/parallel_tests/linear_pp.py delete mode 100644 tests/paddle/parallel_tests/linear_tp.py delete mode 100644 tests/paddle/parallel_tests/transformer_tp.py delete mode 100644 tests/paddle/recompute_tests/recompute_transformer_encoder.py delete mode 100644 tests/paddle/test_install.py delete mode 100644 tests/paddle/test_layers.py delete mode 100644 tests/paddle/test_master_grad.py delete mode 100644 tests/paddle/test_operators.py delete mode 100644 tests/paddle/test_parallel.py delete mode 100644 tests/paddle/test_recompute.py delete mode 100644 tests/paddle/utils.py delete mode 100644 tests/pytorch/custom_ort_ops/.gitignore delete mode 100644 tests/pytorch/custom_ort_ops/CMakeLists.txt delete mode 100644 tests/pytorch/custom_ort_ops/README.md delete mode 100644 tests/pytorch/custom_ort_ops/build.sh delete mode 100755 tests/pytorch/custom_ort_ops/custom_op_library.cc create mode 100644 tests/pytorch/test_cpu_offloading.py delete mode 100644 tests/pytorch/test_onnx_export.py delete mode 100644 tests/pytorch/test_torch_save_load.py create mode 100644 transformer_engine/common/include/transformer_engine/swizzle.h create mode 100644 transformer_engine/common/swizzle/swizzle.cu create mode 100644 transformer_engine/common/transpose/cast_transpose.h create mode 100644 transformer_engine/common/util/cast_gated_kernels.cuh create mode 100644 transformer_engine/common/util/cast_kernels.cuh create mode 100644 transformer_engine/common/util/dequantize_kernels.cuh create mode 100644 transformer_engine/common/util/ptx.cuh delete mode 100644 transformer_engine/paddle/MANIFEST.in delete mode 100644 transformer_engine/paddle/__init__.py delete mode 100644 transformer_engine/paddle/constants.py delete mode 100644 transformer_engine/paddle/cpp_extensions.py delete mode 
100644 transformer_engine/paddle/csrc/common.cpp delete mode 100644 transformer_engine/paddle/csrc/common.h delete mode 100644 transformer_engine/paddle/csrc/custom_ops.cu delete mode 100644 transformer_engine/paddle/csrc/extensions.cpp delete mode 100644 transformer_engine/paddle/distributed.py delete mode 100644 transformer_engine/paddle/fp8.py delete mode 100644 transformer_engine/paddle/fp8_buffer.py delete mode 100644 transformer_engine/paddle/layer/__init__.py delete mode 100644 transformer_engine/paddle/layer/attention.py delete mode 100644 transformer_engine/paddle/layer/base.py delete mode 100644 transformer_engine/paddle/layer/layernorm.py delete mode 100644 transformer_engine/paddle/layer/layernorm_linear.py delete mode 100644 transformer_engine/paddle/layer/layernorm_mlp.py delete mode 100644 transformer_engine/paddle/layer/linear.py delete mode 100644 transformer_engine/paddle/layer/rmsnorm.py delete mode 100644 transformer_engine/paddle/layer/softmax.py delete mode 100644 transformer_engine/paddle/layer/transformer.py delete mode 100644 transformer_engine/paddle/profile.py delete mode 100644 transformer_engine/paddle/recompute.py delete mode 100644 transformer_engine/paddle/setup.py delete mode 100644 transformer_engine/paddle/utils.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/_common.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/activation.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/cast.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/normalization.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/padding.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/transpose.py create mode 100644 transformer_engine/pytorch/csrc/extensions/bias.cpp create mode 100644 transformer_engine/pytorch/csrc/extensions/quantizer.cpp create mode 100644 transformer_engine/pytorch/csrc/extensions/swizzle.cpp create mode 100644 
transformer_engine/pytorch/csrc/extensions/type_converters.cpp rename tests/pytorch/custom_ort_ops/custom_op_library.h => transformer_engine/pytorch/csrc/extensions/util.cpp (53%) mode change 100755 => 100644 create mode 100644 transformer_engine/pytorch/csrc/pybind.h delete mode 100644 transformer_engine/pytorch/csrc/ts_fp8_op.cpp create mode 100644 transformer_engine/pytorch/csrc/util.h delete mode 100755 transformer_engine/pytorch/export.py delete mode 100755 transformer_engine/pytorch/te_onnx_extensions.py rename tests/paddle/test_sanity_import.py => transformer_engine/pytorch/tensor/_internal/__init__.py (69%) create mode 100644 transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/mxfp8_tensor.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 964e71fa8c..4be7a30a86 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -73,23 +73,3 @@ jobs: MAX_JOBS: 1 - name: 'Sanity check' run: python tests/jax/test_sanity_import.py - paddle: - name: 'PaddlePaddle' - runs-on: ubuntu-latest - container: - image: nvcr.io/nvidia/paddlepaddle:24.10-py3 - options: --user root - steps: - - name: 'Checkout' - uses: actions/checkout@v3 - with: - submodules: recursive - - name: 'Build' - run: | - apt-get update - apt-get install -y libgoogle-glog-dev - pip install . -v - env: - NVTE_FRAMEWORK: paddle - - name: 'Sanity check' - run: python tests/paddle/test_sanity_import.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f98fc9aa3a..ee6433d484 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -61,30 +61,3 @@ jobs: export PYTHON_ONLY=1 export TE_PATH=. 
bash ./qa/L0_jax_lint/test.sh - paddle_cpplint: - name: 'PaddlePaddle C++' - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: 'Lint' - run: | - sudo apt-get update - sudo apt-get install pip -y - export CPP_ONLY=1 - export TE_PATH=. - bash ./qa/L0_paddle_lint/test.sh - paddle_pylint: - name: 'PaddlePaddle Python' - runs-on: ubuntu-latest - steps: - - name: 'Checkout' - uses: actions/checkout@v3 - - name: 'Lint' - run: | - sudo apt-get update - sudo apt-get install pip -y - pip install paddlepaddle-gpu - export PYTHON_ONLY=1 - export TE_PATH=. - bash ./qa/L0_paddle_lint/test.sh diff --git a/.gitignore b/.gitignore index 9b61454e21..f491b21f43 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ *.nsys-rep *.ncu-rep *.sqlite -*.onnx *.eggs build/ *.so diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index cc5632eda7..91b7532f33 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699 +Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a diff --git a/README.rst b/README.rst index fbcf05f3c9..8fea8c9d94 100644 --- a/README.rst +++ b/README.rst @@ -174,7 +174,7 @@ To install the latest stable version of Transformer Engine, pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle). +This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). 
Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g. @@ -182,7 +182,7 @@ Alternatively, the package can be directly installed from `Transformer Engine's pip install transformer_engine[pytorch] -To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions. From source ^^^^^^^^^^^ diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 809a0327d8..eb5820cd2d 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.14.0.dev0 +2.1.0.dev0 diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 5744439c1b..a3243d087b 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -129,63 +129,6 @@ def run(self) -> None: super().run() self.extensions = all_extensions - paddle_ext = None - if "paddle" in get_frameworks(): - for ext in self.extensions: - if "paddle" in ext.name: - paddle_ext = ext - break - - # Manually write stub file for Paddle extension - if paddle_ext is not None: - # Load libtransformer_engine.so to avoid linker errors - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Source compilation from top-level (--editable) - search_paths = list(Path(__file__).resolve().parent.parent.iterdir()) - # Source compilation from top-level - search_paths.extend(list(Path(self.build_lib).iterdir())) - - # Dynamically load required_libs. 
- from transformer_engine.common import _load_cudnn, _load_nvrtc - - _load_cudnn() - _load_nvrtc() - else: - # Only during release bdist build for paddlepaddle. - import transformer_engine - - search_paths = list(Path(transformer_engine.__path__[0]).iterdir()) - del transformer_engine - - common_so_path = "" - for path in search_paths: - if path.name.startswith("libtransformer_engine."): - common_so_path = str(path) - assert common_so_path, "Could not find libtransformer_engine" - ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL) - - # Figure out stub file path - module_name = paddle_ext.name - assert module_name.endswith( - "_pd_" - ), "Expected Paddle extension module to end with '_pd_'" - stub_name = module_name[:-4] # remove '_pd_' - stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py") - Path(stub_path).parent.mkdir(exist_ok=True, parents=True) - - # Figure out library name - # Note: This library doesn't actually exist. Paddle - # internally reinserts the '_pd_' suffix. - so_path = self.get_ext_fullpath(module_name) - _, so_ext = os.path.splitext(so_path) - lib_name = stub_name + so_ext - - # Write stub file - print(f"Writing Paddle stub for {lib_name} into file {stub_path}") - from paddle.utils.cpp_extension.extension_utils import custom_write_stub - - custom_write_stub(lib_name, stub_path) - # Ensure that binaries are not in global package space. target_dir = install_dir / "transformer_engine" target_dir.mkdir(exist_ok=True, parents=True) @@ -194,16 +137,10 @@ def run(self) -> None: self.copy_file(ext, target_dir) os.remove(ext) - # For paddle, the stub file needs to be copied to the install location. 
- if paddle_ext is not None: - stub_path = Path(self.build_lib) / "transformer_engine" - for stub in stub_path.glob("transformer_engine_paddle.py"): - self.copy_file(stub, target_dir) - def build_extensions(self): - # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly + # BuildExtensions from PyTorch already handle CUDA files correctly # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed. - if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks(): + if "pytorch" not in get_frameworks(): # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # extra_compile_args is a dict. for ext in self.extensions: diff --git a/build_tools/paddle.py b/build_tools/paddle.py deleted file mode 100644 index f0fcdb8f25..0000000000 --- a/build_tools/paddle.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Paddle-paddle related extensions.""" -from pathlib import Path - -import setuptools -import os - -from .utils import cuda_version - -import paddle - -paddle_version = paddle.__version__.replace(".", "") - - -def setup_paddle_extension( - csrc_source_files, - csrc_header_files, - common_header_files, -) -> setuptools.Extension: - """Setup CUDA extension for Paddle support""" - - # Source files - csrc_source_files = Path(csrc_source_files) - sources = [ - csrc_source_files / "extensions.cpp", - csrc_source_files / "common.cpp", - csrc_source_files / "custom_ops.cu", - ] - - # Header files - include_dirs = [ - common_header_files, - common_header_files / "common", - common_header_files / "common" / "include", - csrc_header_files, - ] - - # Compiler flags - cxx_flags = ["-O3"] - nvcc_flags = [ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - f"-DPADDLE_VERSION={paddle_version}", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - - # Version-dependent CUDA options - try: - version = cuda_version() - except FileNotFoundError: - print("Could not determine CUDA Toolkit version") - else: - if version < (12, 0): - raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - nvcc_flags.extend( - ( - "--threads", - os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), - "-gencode", - "arch=compute_80,code=sm_80", - "-gencode", - "arch=compute_90,code=sm_90", - ) - ) - - # Construct Paddle CUDA extension - sources = [str(path) for path in sources] - include_dirs = [str(path) for path in include_dirs] - from paddle.utils.cpp_extension import CUDAExtension - - ext = CUDAExtension( - sources=sources, - include_dirs=include_dirs, - extra_compile_args={ - "cxx": cxx_flags, - "nvcc": nvcc_flags, - }, - ) - 
ext.name = "transformer_engine_paddle_pd_" - return ext diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index f060e99dff..b8501e1008 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -27,7 +27,6 @@ def setup_pytorch_extension( extensions_dir = csrc_source_files / "extensions" sources = [ csrc_source_files / "common.cpp", - csrc_source_files / "ts_fp8_op.cpp", ] + all_files_in_dir(extensions_dir) # Header files diff --git a/build_tools/utils.py b/build_tools/utils.py index f2a4200685..723f2f200c 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -190,7 +190,12 @@ def cuda_path() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90") + version = cuda_version() + if os.getenv("NVTE_CUDA_ARCHS") is None: + os.environ["NVTE_CUDA_ARCHS"] = ( + "70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90" + ) + return os.getenv("NVTE_CUDA_ARCHS") def cuda_version() -> Tuple[int, ...]: @@ -211,7 +216,7 @@ def cuda_version() -> Tuple[int, ...]: def get_frameworks() -> List[str]: """DL frameworks to build support for""" _frameworks: List[str] = [] - supported_frameworks = ["pytorch", "jax", "paddle"] + supported_frameworks = ["pytorch", "jax"] # Check environment variable if os.getenv("NVTE_FRAMEWORK"): @@ -237,12 +242,6 @@ def get_frameworks() -> List[str]: pass else: _frameworks.append("jax") - try: - import paddle - except ImportError: - pass - else: - _frameworks.append("paddle") # Special framework names if "all" in _frameworks: @@ -311,7 +310,6 @@ def uninstall_te_wheel_packages(): "-y", "transformer_engine_cu12", "transformer_engine_torch", - "transformer_engine_paddle", "transformer_engine_jax", ] ) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index ceebe626f4..9acb22aee6 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -9,7 +9,6 @@ 
BUILD_METAPACKAGE=${2:-true} BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} BUILD_JAX=${5:-true} -BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -63,38 +62,3 @@ if $BUILD_JAX ; then /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi - -if $BUILD_PADDLE ; then - if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then - dnf -y remove --allowerasing cudnn9-cuda-12 - dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 - cd /TransformerEngine/transformer_engine/paddle - - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - 
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - mv dist/* /wheelhouse/ - fi -fi diff --git a/docs/api/common.rst b/docs/api/common.rst index 85201aee5d..5e0a660ae6 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -8,4 +8,4 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.Format -.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False)) +.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) diff --git a/docs/api/framework.rst b/docs/api/framework.rst index acd54fe3b1..0ac1a0e34e 100644 --- a/docs/api/framework.rst +++ b/docs/api/framework.rst @@ -10,4 +10,3 @@ Framework-specific API pytorch jax - paddle diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst deleted file mode 100644 index 3b3ecf55c6..0000000000 --- a/docs/api/paddle.rst +++ /dev/null @@ -1,34 +0,0 @@ -.. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -paddle -====== - -.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs) - -.. 
autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) - :members: forward - -.. autoapifunction:: transformer_engine.paddle.fp8_autocast - -.. autoapifunction:: transformer_engine.paddle.recompute diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 43001feeb3..6d5fe6761d 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -42,8 +42,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.checkpoint -.. autoapifunction:: transformer_engine.pytorch.onnx_export - .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 27017b4773..16a3b05466 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -14,11 +14,10 @@ "
Figure 1: Dot product attention.
\n", "\n", "\n", - "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n", + "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n", "\n", "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n", - "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n", - "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)" + "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)" ] }, { @@ -56,15 +55,6 @@ " \n", " JAX-native attention (`_UnfusedDotProductAttention`)\n", " \n", - " \n", - " PaddlePaddle\n", - " cuDNN attention (`_te_forward`) \n", - " [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n", - " \n", - " \n", - " \n", - " PaddlePaddle-native attention (`_pd_forward`)\n", - " \n", " \n", "" ] @@ -87,7 +77,7 @@ "
\n", "Note: \n", " \n", - "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", + "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n", "
\n" ] }, @@ -102,13 +92,13 @@ "\n", "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n", "\n", - "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", + "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "\n", "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "\n", "### 1.3 cuDNN Attention\n", "\n", - "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", + "The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", "\n", "\n", " \n", @@ -153,9 +143,9 @@ " \n", "
\n", "\n", - "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n", + "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n", "\n", - "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n", + "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n", "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n", "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n", "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", @@ -244,10 +234,6 @@ " JAX\n", " cuDNN attention > JAX-native attention\n", " \n", - " \n", - " PaddlePaddle\n", - " cuDNN attention > PaddlePaddle-native attention \n", - " \n", "" ] }, @@ -266,7 +252,7 @@ "
\n", "Note:\n", " \n", - "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n", + "These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n", "
" ] }, @@ -382,7 +368,7 @@ "
\n", "Note\n", " \n", - "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n", + "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX in the future.\n", "
\n", "\n", "### 2.3 Example Tests\n", @@ -399,7 +385,7 @@ "source": [ "## 3. Backend Support\n", "\n", - "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n", + "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n", "\n", "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", @@ -442,7 +428,7 @@ "**qkv_layout=thd_thd_thd:**\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "\n", - "As of v1.10, Transformer Engine has the following support matrix.\n", + "As of v2.0, Transformer Engine has the following support matrix.\n", "\n", "\n", " \n", @@ -462,13 +448,13 @@ " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", "
\n", - " JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", + " JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", "
Framework-native attention`bshd`, `sbhd`PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layoutsPyTorch, JAX: 2 formats, i.e. 10 layouts
\n", "\n", @@ -492,7 +478,7 @@ "\n", "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n", "\n", - "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n", + "Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n", "\n", "\n", " \n", @@ -512,21 +498,21 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", "
Framework-native attention
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax, PaddlePaddle)
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax)
  • \n", "\n", - "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", + "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n", "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n", "\n", "\n", - "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", + "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", "\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. 
For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", @@ -566,7 +552,7 @@ "\n", "### 3.3 Attention Bias\n", "\n", - "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n", + "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n", "\n", "\n", " \n", @@ -591,7 +577,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -620,7 +606,7 @@ "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. 
The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", diff --git a/docs/installation.rst b/docs/installation.rst index fae01c64fa..ee7afa9006 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -37,7 +37,7 @@ Transformer Engine can be directly installed from `our PyPI desired_test_accuracy - - @unittest.skipIf( - paddle.device.cuda.get_device_capability() < (8, 0), - "BF16 MNIST example requires Ampere+ GPU", - ) - def test_te_bf16(self): - """Test Transformer Engine with BF16""" - self.args.use_te = True - self.args.use_fp8 = False - self.args.save_model = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - @unittest.skipIf(not gpu_has_fp8, reason) - def test_te_fp8(self): - """Test Transformer Engine with FP8""" - self.args.use_te = True - self.args.use_fp8 = True - self.args.save_model = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - @unittest.skipIf(not gpu_has_fp8, reason) - def test_te_fp8_calibration(self): - """Test Transformer Engine with FP8 calibration""" - self.args.use_te = True - self.args.use_fp8 = False - self.args.use_fp8_infer = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - -if __name__ == "__main__": - train_and_evaluate(mnist_parser(None)) diff --git a/pylintrc b/pylintrc index b80679d72c..4af0c6b427 100644 --- a/pylintrc +++ b/pylintrc @@ -2,7 +2,6 @@ extension-pkg-whitelist=flash_attn_2_cuda, torch, transformer_engine_torch, - transformer_engine_paddle, 
transformer_engine_jax extension-pkg-allow-list=transformer_engine.transformer_engine_jax diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 6eff047721..8e2e540293 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -8,7 +8,7 @@ pip install "nltk>=3.8.2" pip install pytest==8.2.1 : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py # Test without custom calls NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh deleted file mode 100644 index 1c26bd265b..0000000000 --- a/qa/L0_paddle_lint/test.sh +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: "${TE_PATH:=/opt/transformerengine}" - -pip install cpplint==1.6.0 pylint==3.3.1 -if [ -z "${PYTHON_ONLY}" ] -then - cd $TE_PATH - echo "Checking common API headers" - cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include - echo "Checking C++ files" - cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common - cpplint --recursive transformer_engine/paddle -fi -if [ -z "${CPP_ONLY}" ] -then - cd $TE_PATH - echo "Checking Python files" - pylint --recursive=y transformer_engine/common transformer_engine/paddle -fi diff --git a/qa/L0_paddle_unittest/test.sh b/qa/L0_paddle_unittest/test.sh deleted file mode 100644 index 9312f22ba4..0000000000 --- a/qa/L0_paddle_unittest/test.sh +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# -# See LICENSE for license information. - -set -xe - -pip install pytest==8.2.1 -: ${TE_PATH:=/opt/transformerengine} -pytest -Wignore -v $TE_PATH/tests/paddle -pytest -Wignore -v $TE_PATH/examples/paddle/mnist diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh deleted file mode 100644 index 5116bdb5cf..0000000000 --- a/qa/L0_paddle_wheel/test.sh +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: "${TE_PATH:=/opt/transformerengine}" - -# Install dependencies -# Note: Need to install wheel locally since PaddlePaddle container -# already contains APT install. -pip install pydantic -pip install --user wheel==0.44.0 - -cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle - -VERSION=`cat $TE_PATH/build_tools/VERSION.txt` -WHL_BASE="transformer_engine-${VERSION}" - -# Core wheel. -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -python -m wheel unpack dist/* -sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -python -m wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel -pip install dist/*.whl --no-deps - -cd transformer_engine/paddle -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -pip install dist/* - -python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 793fa47259..dd7f95bce0 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -11,11 +11,10 @@ pytest -v 
-s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py +NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py -pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index ee7c28ca5f..8ee0be1af5 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -8,8 +8,8 @@ set -e pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py +# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh deleted file mode 100644 index 8e4ef03b8e..0000000000 
--- a/qa/L1_pytorch_onnx_test/test.sh +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: ${TE_PATH:=/opt/transformerengine} - -pip install pytest==8.2.1 onnxruntime==1.19.2 - -# Build custom ONNX Runtime operators -export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops -bash $CUSTOM_ORT_OPS_PATH/build.sh - -# Run tests -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index e63ba358a5..8ed3002214 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -12,7 +12,14 @@ pip install pytest==8.2.1 export MAX_JOBS=4 # Iterate over Flash Attention versions -FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1) +sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` +if [ $sm_arch -gt 90 ] +then + FA_versions=(2.7.3) +else + FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1) +fi + for fa_version in "${FA_versions[@]}" do @@ -21,10 +28,10 @@ do then pip install flash-attn==${fa_version} else - pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" + pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" python_path=`python -c "import site; print(site.getsitepackages()[0])"` mkdir -p $python_path/flashattn_hopper - wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py + wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py fi # Run tests diff --git a/setup.py b/setup.py index 643dd7a908..1d9818458e 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ """Installation 
script.""" import os +import sys import time from pathlib import Path from typing import List, Tuple @@ -35,14 +36,13 @@ if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension -elif "paddle" in frameworks: - from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension CMakeBuildExtension = get_build_ext(BuildExtension) +archs = cuda_archs() class TimedBdist(bdist_wheel): @@ -57,7 +57,7 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" - cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] + cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( os.getenv("MPI_HOME") is not None @@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch"]) - test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + test_reqs.extend(["numpy", "torchvision", "prettytable"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) - test_reqs.extend(["numpy", "praxis"]) - if "paddle" in frameworks: - install_reqs.append("paddlepaddle-gpu") - test_reqs.append("numpy") + # test_reqs.extend(["numpy", "praxis"]) + test_reqs.extend(["numpy"]) return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] @@ -135,7 +133,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], } else: setup_requires, install_requires, test_requires = setup_requirements() @@ -169,16 +166,6 @@ def setup_requirements() -> Tuple[List[str], 
List[str], List[str]]: current_file_path / "transformer_engine", ) ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", - ) - ) # Configure package setuptools.setup( diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index d8c8d99fac..081cd14eb4 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -5,7 +5,11 @@ cmake_minimum_required(VERSION 3.18) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 90) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 178dc5e8dd..ce78fcaae2 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,23 +3,33 @@ # See LICENSE for license information. 
add_executable(test_operator + test_cast.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu test_qdq.cu - test_cast_transpose.cu + test_cast_mxfp8.cu + test_dequantize_mxfp8.cu test_transpose.cu + test_cast_transpose.cu test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu test_normalization.cu + test_normalization_mxfp8.cu test_multi_cast_transpose.cu test_multi_padding.cu test_causal_softmax.cu + test_swizzle.cu ../test_common.cu) +find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) -target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) -target_compile_options(test_operator PRIVATE -O2) +target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) +target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) -gtest_discover_tests(test_operator) +gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index cec997d078..4224f199f4 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -21,58 +21,6 @@ using namespace transformer_engine; -namespace { - -// forward - -float gelu(const float x) { - return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x))); -} - -float silu(const float x) { - return x / (1 + expf(-x)); -} - -float relu(const float x) { - return x > 0 ? x : 0; -} - -float srelu(const float x) { - return x > 0 ? 
x * x : 0; -} - -float qgelu(const float x) { - return x / (1 + expf(-1.702f * x)); -} - -// backward - -float dgelu(const float x) { - const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x)); - return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + - 0.5f * (1.f + tanh_out); -} - -float dsilu(const float x) { - const float sigmoid = 1.f / (1 + expf(-x)); - return x * sigmoid * (1.f - sigmoid) + sigmoid; -} - -float drelu(const float x) { - return x > 0.f ? 1.f : 0.f; -} - -float dsrelu(const float x) { - return fmaxf(2.f * x, 0.f); -} - -float dqgelu(const float x) { - const float sigmoid = 1.f / (1 + expf(-1.702f * x)); - return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid; -} - -} // namespace - template void compute_ref_act_cast(const IT *input_h, OT *output_h, @@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h, const size_t H) { CT amax = 0.; + #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); @@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h, const size_t N, const size_t H) { using CT = float; + #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); @@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C const int col = H * 2; + #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT gelu_elt = static_cast(input_h[i * col + j]); @@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h const int col = H * 2; using CT = float; + #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT 
grad = static_cast(grad_h[i * H + j]); @@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output({ N, H }, otype); - Tensor igrad({ N, H }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype); + Tensor igrad("igrad", { N, H }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); @@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) { nvte_act(input.data(), output.data(), 0); float ref_amax; - compute_ref_act_cast(input.cpu_dptr(), ref_output.get(), + compute_ref_act_cast(input.rowwise_cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) { nvte_dact(ograd.data(), input.data(), igrad.data(), 0); - compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(), + compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); @@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H * 2}, itype); - Tensor output({N, H}, otype); - Tensor igrad({ N, H * 2 }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H}, otype); + Tensor igrad("igrad", { N, H * 2 }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); @@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) { nvte_act(input.data(), output.data(), 0); float ref_amax; - compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(), + compute_ref_glu_act_cast(input.rowwise_cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -242,15 +194,19 @@ void 
performTestGLU(const size_t N, const size_t H) { ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + auto [atol, rtol] = getTolerances(DType::kFloat32); + compareResults("amax", output.amax(), ref_amax, atol, rtol); + if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + const float ref_scale = 1.f / output.scale(); + compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr(), ref_scale, atol, rtol); + } } auto [atol, rtol] = getTolerances(otype); compareResults("output_gelu", output, ref_output.get(), atol, rtol); nvte_dact(ograd.data(), input.data(), igrad.data(), 0); - compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(), + compute_ref_dglu_act_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu new file mode 100644 index 0000000000..f57d1f035d --- /dev/null +++ b/tests/cpp/operator/test_cast.cu @@ -0,0 +1,130 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, + const size_t size, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i] = OutputType(scale * current); + } + *amax = current_max; +} + +template +void performTest(const std::vector& shape) { + using namespace test; + + const size_t full_size = product(shape); + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype); + + std::unique_ptr ref_output_c = std::make_unique(full_size); + + fillUniform(&input); + setRandomScale(&output_c); + + nvte_quantize(input.data(), output_c.data(), 0); + + float ref_amax; + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + full_size, &ref_amax, output_c.scale()); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); +} + +std::vector> test_cases = { + {16}, + {16000}, + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + 
{5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; +} // namespace + +class CastTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastTestSuite, TestCast) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu new file mode 100644 index 0000000000..1f0a9305d8 --- /dev/null +++ b/tests/cpp/operator/test_cast_dbias.cu @@ -0,0 +1,181 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref_cast_dbias(const IT *input_h, + const CT scale, + OT *output_c_h, + CT *amax_h, + IT *dbias_h, + const size_t N, + const size_t H) { + CT amax = 0.; + + std::vector acc_dbias(H, 0.); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT elt = static_cast(input_h[i * H + j]); + + // update amax + amax = std::abs(elt) > amax ? std::abs(elt) : amax; + + output_c_h[i * H + j] = static_cast(scale * elt); + + // dbias + acc_dbias[j] += elt; + } + } + + *amax_h = amax; + + for (size_t i = 0; i < H; i++) { + dbias_h[i] = static_cast(acc_dbias[i]); + } +} + +template +void performTest(const std::vector& shape) { + using namespace test; + using CType = fp32; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); + + Tensor input("input", shape, itype); + + Tensor output_c("output_c", shape, otype); + // dbias has the same data type with "output grad" + Tensor dbias("dbias", {H}, itype); + + fillUniform(&input); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(N*H); + std::unique_ptr ref_output_dbias = std::make_unique(H); + + CType ref_amax; + compute_ref_cast_dbias(input.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + ref_output_dbias.get(), + N, H); + + Tensor workspace; + + nvte_quantize_dbias(input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + cudaDeviceSynchronize(); + auto err = 
cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + rtol_dbias *= 4; + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; + +} // namespace; + + +class CastDBiasTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastDBiasTestSuite, TestCastDBias) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastDBiasTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const 
auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu new file mode 100644 index 0000000000..20ea5c31f1 --- /dev/null +++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu @@ -0,0 +1,196 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void compute_ref_cast_dbias_dgelu(const IT *input, + const IT *grad, + const CT scale, + OT *output_c, + CT *amax_h, + IT *dbias, + const size_t N, + const size_t H) { + CT amax = 0.; + + std::vector acc_dbias(H, 0.); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT in_elt = static_cast(input[i * H + j]); + const CT in_grad = static_cast(grad[i * H + j]); + + const CT elt = in_grad * static_cast(dgelu(static_cast(in_elt))); + const CT elt_abs = std::abs(elt); + + // update amax + if (elt_abs > amax) { + amax = elt_abs; + } + + output_c[i * H + j] = static_cast(scale * elt); + + // dbias + acc_dbias[j] += elt; + } + } + + *amax_h = amax; + + for (size_t i = 0; i < H; i++) { + dbias[i] = static_cast(acc_dbias[i]); + } +} + +template +void performTest(const std::vector& shape) { + using namespace test; + using CType = fp32; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + + Tensor output_c("output_c", 
shape, otype); + // dbias has the same data type with "output grad" + Tensor dbias("dbias", {H}, itype); + + fillUniform(&input); + fillUniform(&grad); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(N*H); + std::unique_ptr ref_output_dbias = std::make_unique(H); + + CType ref_amax; + compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + ref_output_dbias.get(), + N, H); + + Tensor workspace; + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + rtol_dbias *= 4; + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; + +} // namespace; + + +class CastDBiasDGeluTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) { + using namespace 
transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastDBiasDGeluTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu new file mode 100644 index 0000000000..35ae462106 --- /dev/null +++ b/tests/cpp/operator/test_cast_gated_swiglu.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void compute_ref_cast_dgated_swiglu(const IType * const grad, + const IType * const input, + const float scale, + OType * const output, + float * const amax_ptr, + const size_t rows, + const size_t cols) { + float amax = 0; + const size_t stride = cols * 2; + + #pragma omp parallel for reduction(max: amax) proc_bind(spread) + for (size_t i = 0; i < rows; i++) { + for (size_t j = 0; j < cols; j++) { + float grad_elt = static_cast(grad[i * cols + j]); + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + float after_dgate = grad_elt * silu(silu_elt); + + if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); } + if (abs(after_dgate) > amax) { amax = abs(after_dgate); } + + output[i * stride + j] = static_cast(scale * after_dsilu); + output[i * stride + cols + j] = static_cast(scale * after_dgate); + } + } + + *amax_ptr = amax; +} + +template +void performTest(const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + std::vector input_shape = shape; + input_shape[input_shape.size() - 1] *= 2; + + const size_t input_size = product(input_shape); + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + Tensor grad("grad", shape, itype); + Tensor input("input", input_shape, itype); + Tensor output_c("output_c", input_shape, otype); + + fillUniform(&grad); + fillUniform(&input); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(input_size); + + nvte_dswiglu(grad.data(), input.data(), 
output_c.data(), 0); + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax; + compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + rows, + cols); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {217, 256}, + {1296}, + {5, 4, 3, 160}, +}; + +} // namespace + +class CastSwiGLUTestSuite + : public ::testing::TestWithParam>> {}; + +TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + if (size.back() % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, performTest(size););); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, CastSwiGLUTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo &info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + 
test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu new file mode 100644 index 0000000000..cb38a5a74a --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -0,0 +1,636 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +template +void scale_block(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_c, + float* dbias, + fp8e8m0* output_scales, + const size_t scale_idx, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) { + float amax = 0.0f; + + // Find the absolute maximum value in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == 
ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + dbias[j] += elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + amax = std::max(amax, std::abs(elt)); + } + } + + const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + output_scales[scale_idx] = biased_exponent; + + // Quantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + output_c[idx] = static_cast(elt * scale_reciprocal); + } + } +} + +template +void compute_ref_x1(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_c, + fp8e8m0* output_scales, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) +{ + std::vector output_dbias_fp32(cols, 0); + + const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y; + const size_t blocks_X = (cols + block_size_X - 1) / block_size_X; + + for (size_t ii = 0; ii < blocks_Y; ++ii) { + const size_t i_min = ii * block_size_Y; + const size_t i_max = std::min((ii + 1) * block_size_Y, rows); + for (size_t jj = 0; jj < blocks_X; ++jj) { + const size_t j_min = jj * block_size_X; + const size_t j_max = std::min((jj + 1) * block_size_X, cols); + const size_t scale_idx = ii * scales_stride + jj; + scale_block( + processing_method, input, grad, output_c, 
output_dbias_fp32.data(), + output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + } + } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); + } +} + +template +void compute_ref_x2(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { + compute_ref_x1( + processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, + rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1( + processing_method, input, grad, output_colwise, scales_colwise, output_dbias, + rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ + +template +void performTest_x1(const ProcessingMethod processing_method, + const std::vector& shape, + const bool rowwise, + const bool colwise, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + if (shape.size() < 2 && colwise) { + GTEST_SKIP(); + } + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 
1 : 32; + + const std::array scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows, + block_size_cols); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); + + std::unique_ptr ref_output_c = std::make_unique(rows * cols); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias(grad.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + nvte_gelu(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, 
cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x1(processing_method, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + + const uint8_t * const gpu_scales_ptr = rowwise + ? output_c.rowwise_cpu_scale_inv_ptr() + : output_c.columnwise_cpu_scale_inv_ptr(); + + compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + + if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x2(const ProcessingMethod processing_method, + const std::vector& shape, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + if (shape.size() < 2) { + GTEST_SKIP(); + } + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1); + + 
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); + + std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), 
workspace.dtype()); + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + nvte_dgelu(grad.data(), input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + nvte_gelu(input.data(), output.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x2(processing_method, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise); + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise); + + if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} + +std::vector> matrix_sizes = { + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {256, 65536}, + {2048, 
6144}, + {16384, 128}, + {32768, 160}, + {4096, 1632}, + {1024}, + {8, 32, 1024}, + {16, 8, 4, 512}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + {32, 32}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + +} // namespace + +class FusedCastMXFP8TestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase>> {}; + +#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ +switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ + case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ + case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ + case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ +} + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) 
\ +switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ + case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ + case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ + case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ + case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ +} + +TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const auto block_size = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); + const InputsFillCase fill_case = std::get<6>(GetParam()); + + // Skips non Act tests if the Activation type is not an identity + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + GTEST_SKIP(); + } + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + if (processing_method == ProcessingMethod::CAST_ACT) { + // Forward activations + ACT_FUNC_SWITCH(Act_type, OP, + 
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, matrix_size, + block_size.first, block_size.second, fill_case); + } + ); + ); + ); + } else { + DACT_FUNC_SWITCH(Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, matrix_size, + block_size.first, block_size.second, fill_case); + } + ); + ); + ); + } +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "Identity"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + 
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)) + "X" + + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + std::to_string(std::get<3>(info.param).first) + + "X" + std::to_string(std::get<3>(info.param).second) + + "X" + test::typeName(std::get<4>(info.param)) + + "X" + test::typeName(std::get<5>(info.param)) + + "X" + test::caseName(std::get<6>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu new file mode 100644 index 0000000000..6acbdefeab --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -0,0 +1,470 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void scale_block(const IType* grad, + const IType* input, + OType* output, + fp8e8m0* output_scales, + const size_t scale_idx, + const size_t scale_idx_gate, + float& thread_amax, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) { + + float block_amax = 0.0f; + float block_amax_gate = 0.0f; + const size_t stride = cols * 2; + + // Find the absolute maximum value in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + float gated_amax_act = 0; + float gated_amax_gate = 0; + + if constexpr (IS_DGATED) { + const float grad_elt = static_cast(grad[i * cols + j]); + const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + const float after_dgate = silu(silu_elt) * grad_elt; + gated_amax_act = abs(after_dsilu); + gated_amax_gate = abs(after_dgate); + } else { + const float after_silu = silu(silu_elt) * gate_elt; + gated_amax_act = abs(after_silu); + } + + if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } + if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } + } + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * + Quantized_Limits::max_reciprocal()); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + output_scales[scale_idx] = biased_exponent; + float scale_reciprocal_gate = 1; + if constexpr (IS_DGATED) { + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * + Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent); + 
output_scales[scale_idx_gate] = biased_exponent; + } + + + // Quantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + if constexpr (IS_DGATED) { + const float grad_elt = static_cast(grad[i * cols + j]); + const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + const float after_dgate = silu(silu_elt) * grad_elt; + output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); + output[i * stride + cols + j] = static_cast(after_dgate * + scale_reciprocal_gate); + } else { + const float after_silu = silu(silu_elt) * gate_elt; + output[i * cols + j] = static_cast(after_silu * scale_reciprocal); + } + + } + } + thread_amax = std::max(thread_amax, block_amax); + thread_amax = std::max(thread_amax, block_amax_gate); +} + +template +void compute_ref_x1(const IType* grad, + const IType* input, + OType* output, + fp8e8m0* output_scales, + float& ref_amax, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) { + const size_t tile_size_Y = std::max(32lu, block_size_Y); + const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; + const size_t blocks_per_tile_X = tile_size_X / block_size_X; + + float amax = 0; + #pragma omp parallel reduction(max: amax) proc_bind(spread) + { + float thread_amax = 0; + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + for (size_t ii = 0; ii < 
blocks_per_tile_Y; ++ii) { + const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; + const size_t block_offset_Y = ii * block_size_Y; + const size_t i_min = tile_offset_Y + block_offset_Y; + if (i_min >= rows) continue; + const size_t i_max = std::min(i_min + block_size_Y, rows); + + for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { + const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; + const size_t block_offset_X = jj * block_size_X; + const size_t j_min = tile_offset_X + block_offset_X; + if (j_min >= cols) continue; + const size_t j_max = std::min(j_min + block_size_X, cols); + + const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; + const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + + cols / block_size_X; + scale_block( + grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, + thread_amax, i_min, i_max, j_min, j_max, cols); + } + } + } + if (thread_amax > amax) { + amax = thread_amax; + } + } + ref_amax = amax; +} + +template +void compute_ref_x2(const IType* grad, + const IType* input, + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + float& ref_amax, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { + compute_ref_x1( + grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1( + grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x1(const size_t rows, + const size_t cols, + const size_t 
block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); + NVTE_CHECK(rowwise || colwise); + + // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; + // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; + // std::cout << "blocks_Y: " << blocks_Y << std::endl; + // std::cout << "blocks_X: " << blocks_X << std::endl; + // std::cout << "scales_stride: " << scales_stride << std::endl; + + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); + + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const std::array scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows, + block_size_cols); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + Tensor output("output", std::vector{ rows, output_cols }, otype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * output_cols); + std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + for (size_t i = 0; i < blocks_Y * blocks_X; ++i) { + ref_output_scales[i] = 0; + } + + // fillCase(&grad, fill_case); + if constexpr (IS_DGATED) { + fillUniform(&grad); + } + fillUniform(&input); + + if constexpr (IS_DGATED) { + nvte_dswiglu(grad.data(), input.data(), output.data(), 0); + } else { + nvte_swiglu(input.data(), output.data(), 0); + } + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax = 0; + 
compute_ref_x1(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_scales.get(), + ref_amax, + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + + const uint8_t * const gpu_scales_ptr = rowwise + ? output.rowwise_cpu_scale_inv_ptr() + : output.columnwise_cpu_scale_inv_ptr(); + if (rowwise) { + compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } else { + compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x2(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); + + const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor output("output", std::vector{ rows, output_cols }, otype, + true, true, NVTE_MXFP8_1D_SCALING); + + std::unique_ptr ref_output_rowwise = std::make_unique(rows * output_cols); + std::unique_ptr ref_output_colwise = std::make_unique(rows * output_cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + + for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) { + ref_scales_rowwise[i] = 0; + } + for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) { + ref_scales_colwise[i] = 0; + } + + // fillCase(&grad, fill_case); + if constexpr (IS_DGATED) { + fillUniform(&grad); + } + fillUniform(&input); + + if constexpr (IS_DGATED) { + nvte_dswiglu(grad.data(), input.data(), output.data(), 0); + } else { + nvte_swiglu(input.data(), output.data(), 0); + } + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax = 0; + compute_ref_x2(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), 
+ ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise); + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise); +} + +std::vector> matrix_sizes = { + {1, 32}, + {16, 64}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {768, 1024}, + {65536, 128}, + {16384, 1632}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + {32, 32}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +std::vector is_dgated_op = { + true, + false +}; + +} // namespace + +class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase, + bool>> {}; + +TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto matrix_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = 
std::get<3>(GetParam()); + const InputsFillCase fill_case = std::get<4>(GetParam()); + const bool IS_DGATED = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, + if (block_size.first == 1 || block_size.second == 1) { + if (IS_DGATED) { + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } else { + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } + } else { + if (IS_DGATED) { + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } else { + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastMXFP8_GatedActTestSuite, + ::testing::Combine( + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), + ::testing::ValuesIn(is_dgated_op)), + [](const testing::TestParamInfo& info) { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + test::caseName(std::get<4>(info.param)) + "X" + + (std::get<5>(info.param) ? 
"DGATED" : "GATED"); + return name; + }); diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 05fcafb0b1..830682eec3 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -14,7 +14,7 @@ #include #include -#include +#include #include "../test_common.h" using namespace transformer_engine; @@ -45,36 +45,34 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output_c({ N, H }, otype); - Tensor output_t({ H, N }, otype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); std::unique_ptr ref_output_c = std::make_unique(N * H); std::unique_ptr ref_output_t = std::make_unique(N * H); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); - nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0); + nvte_quantize(input.data(), output.data(), 0); float ref_amax; - compute_ref(input.cpu_dptr(), ref_output_c.get(), + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), ref_output_t.get(), N, H, &ref_amax, - output_c.scale()); + output.scale()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, 
ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } std::vector> test_cases = {{2048, 12288}, diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index 72d890f8e9..53918e2699 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -15,7 +15,7 @@ #include #include -#include +#include #include "../test_common.h" using namespace transformer_engine; @@ -64,26 +64,23 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - DType ctype = TypeInfo::dtype; - Tensor input({N, H}, itype); + Tensor input("input", {N, H}, itype); - Tensor output_c({N, H}, otype); - Tensor output_t({ H, N}, otype); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N*H); std::unique_ptr ref_output_t = std::make_unique(N*H); std::unique_ptr ref_output_dbias = std::make_unique(H); CType ref_amax; - compute_ref_cast_transpose_dbias(input.cpu_dptr(), - output_c.scale(), + compute_ref_cast_transpose_dbias(input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, @@ -92,22 +89,20 @@ void performTest(const size_t N, const size_t H) { Tensor workspace; - nvte_cast_transpose_dbias(input.data(), - output_c.data(), - output_t.data(), - dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias(input.data(), + output.data(), + dbias.data(), + workspace.data(), + 0); - workspace = Tensor(workspace.shape(), 
workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_cast_transpose_dbias(input.data(), - output_c.data(), - output_t.data(), - dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias(input.data(), + output.data(), + dbias.data(), + workspace.data(), + 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -115,17 +110,17 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); auto [atol_dbias, rtol_dbias] = getTolerances(itype); rtol_dbias *= 4; - compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias); + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } std::vector> test_cases = {{64, 400}, diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index d3ba31fa53..15c7d8d665 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -75,29 +75,26 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = 
TypeInfo::dtype; - DType ctype = TypeInfo::dtype; - Tensor input({N, H}, itype); - Tensor gelu_input({N, H}, itype); + Tensor input("input", {N, H}, itype); + Tensor gelu_input("gelu_input", {N, H}, itype); - Tensor output_c({N, H}, otype); - Tensor output_t({ H, N}, otype); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); fillUniform(&gelu_input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N*H); std::unique_ptr ref_output_t = std::make_unique(N*H); std::unique_ptr ref_output_dbias = std::make_unique(H); CType ref_amax; - compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr(), - gelu_input.cpu_dptr(), - output_c.scale(), + compute_ref_cast_transpose_dbias_dgelu(input.rowwise_cpu_dptr(), + gelu_input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, @@ -108,19 +105,17 @@ void performTest(const size_t N, const size_t H) { nvte_cast_transpose_dbias_dgelu(input.data(), gelu_input.data(), - output_c.data(), - output_t.data(), + output.data(), dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_cast_transpose_dbias_dgelu(input.data(), gelu_input.data(), - output_c.data(), - output_t.data(), + output.data(), dbias.data(), workspace.data(), 0); @@ -131,18 +126,18 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, 
rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); auto [atol_dbias, rtol_dbias] = getTolerances(itype); rtol_dbias *= 4; - compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias); + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } std::vector> test_cases = {{64, 400}, diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu index 03cec4e658..ae2da7bad2 100644 --- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu +++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu @@ -74,24 +74,22 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor grad({N, H}, itype); - Tensor input({N, H * 2}, itype); - Tensor output_c({N, H * 2}, otype); - Tensor output_t({H * 2, N}, otype); + Tensor grad("grad", {N, H}, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H * 2}, otype, true, true); fillUniform(&grad); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N * H * 2); std::unique_ptr ref_output_t = std::make_unique(N * H * 2); - nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0); + nvte_dgeglu_cast_transpose(grad.data(), input.data(), output.data(), 0); CType ref_amax; - compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr(), input.cpu_dptr(), - output_c.scale(), 
ref_output_c.get(), ref_output_t.get(), + compute_ref_cast_transpose_dgated_gelu(grad.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -100,14 +98,14 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } std::vector> test_cases = {{64, 400}, {4096, 2048}, {768, 2816}, diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu index 5401b03296..2fdc0a524d 100644 --- a/tests/cpp/operator/test_causal_softmax.cu +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -153,11 +153,11 @@ void performTest( DType itype = TypeInfo::dtype; - Tensor data_in({ batches, heads, rows, cols }, itype); - Tensor softmax_out({ batches, heads, rows, cols }, itype); - Tensor softmax_in({ batches, heads, rows, cols }, itype); - Tensor grads_in({ batches, heads, rows, cols }, itype); - Tensor grads_out({ batches, heads, rows, cols }, itype); + Tensor data_in("data_in", { batches, heads, rows, cols }, itype); + Tensor softmax_out("softmax_out", { batches, heads, rows, cols }, itype); + Tensor 
softmax_in("softmax_in", { batches, heads, rows, cols }, itype); + Tensor grads_in("grads_in", { batches, heads, rows, cols }, itype); + Tensor grads_out("grads_out", { batches, heads, rows, cols }, itype); const size_t elements_total = batches * heads * rows * cols; std::unique_ptr softmax_out_ref = std::make_unique(elements_total); @@ -175,9 +175,9 @@ void performTest( // Reference implementations - compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr(), + compute_fwd_ref(softmax_out_ref.get(), data_in.rowwise_cpu_dptr(), compute_buffer.get(), scaling_factor, batches, heads, rows, cols); - compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr(), softmax_in.cpu_dptr(), + compute_bwd_ref(grads_out_ref.get(), grads_in.rowwise_cpu_dptr(), softmax_in.rowwise_cpu_dptr(), compute_buffer.get(), scaling_factor, batches, heads, rows, cols); cudaDeviceSynchronize(); @@ -187,8 +187,8 @@ void performTest( if(itype == DType::kBFloat16) { atol = 1e-3; } - compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol); - compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol); + compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), true, atol, rtol); + compareResults("softmax_bwd", grads_out, grads_out_ref.get(), true, atol, rtol); } // [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor] diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu new file mode 100644 index 0000000000..701deb38bb --- /dev/null +++ b/tests/cpp/operator/test_dequantize_mxfp8.cu @@ -0,0 +1,452 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void dequantize_block(const InputType* input, + OutputType* output, + fp8e8m0* scales, + const size_t scale_idx, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) +{ + const fp8e8m0 biased_exponent = scales[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float elem_scale = block_scale; + + // Dequantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float elt = static_cast(input[idx]); + output[idx] = static_cast(elt * elem_scale); + } + } +} + +template +void compute_ref_x1(const InputType* input, + OutputType* output, + fp8e8m0* scales, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) +{ + const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y; + const size_t blocks_X = (cols + block_size_X - 1) / block_size_X; + + for (size_t ii = 0; ii < blocks_Y; ++ii) { + const size_t i_min = ii * block_size_Y; + const size_t i_max = std::min((ii + 1) * block_size_Y, rows); + for (size_t jj = 0; jj < blocks_X; ++jj) { + const size_t j_min = jj * block_size_X; + const size_t j_max = std::min((jj + 1) * block_size_X, cols); + const size_t scale_idx = ii * scales_stride + jj; + dequantize_block( + input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols); + } + } +} + +template +void compute_ref_x2(const InputType* input, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* 
scales_rowwise, + fp8e8m0* scales_colwise, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ + compute_ref_x1(input, output_rowwise, scales_rowwise, rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1(input, output_colwise, scales_colwise, rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +void generate_scales(fp8e8m0 * const scales_ref, + fp8e8m0 * const scales, + const size_t blocks_num, + std::mt19937& gen, + std::uniform_int_distribution dis) +{ + for (size_t i = 0; i < blocks_num; ++i) { + const fp8e8m0 val = dis(gen); + scales_ref[i] = val; + scales[i] = val; + } +} + +template +void generate_data(InputType * const data, + const size_t rows, + const size_t cols, + std::mt19937& gen, + std::uniform_real_distribution<>& dis, + std::uniform_real_distribution<>& dis_sign) +{ + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const bool is_negative = (dis_sign(gen) < 0.0); + double val = dis(gen); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } + } +} + +template +void fill_tensor_data(Tensor& input, + fp8e8m0 * const scales_rowwise, + fp8e8m0 * const scales_colwise, + const bool is_rowwise_scaling, + const bool is_colwise_scaling, + const size_t rows, + const size_t cols, + const size_t blocks_num_rowwise, + const size_t blocks_num_colwise) +{ + const double minAbs = Numeric_Traits::minNorm; + const double maxAbs = Numeric_Traits::maxNorm; + static std::mt19937 gen(12345); + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + std::uniform_int_distribution int_dis(0, 255); + + if (is_rowwise_scaling) { + generate_scales(scales_rowwise, input.rowwise_cpu_scale_inv_ptr(), blocks_num_rowwise, gen, int_dis); + generate_data(input.rowwise_cpu_dptr(), rows, cols, gen, 
dis, dis_sign); + } + + if (is_colwise_scaling) { + generate_scales(scales_colwise, input.columnwise_cpu_scale_inv_ptr(), blocks_num_colwise, gen, int_dis); + generate_data(input.columnwise_cpu_dptr(), rows, cols, gen, dis, dis_sign); + } + + input.from_cpu(); +} + +// Dequantize along single dimension (either row- or columnwise) +template +void performTest_x1(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 1 : 32; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise; + const size_t scales_stride = rowwise ? 
blocks_X_rowwise : blocks_X_colwise; + + Tensor input("input", { rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + // Output data are written to the rowwise ptr regardless of the scaling direction + Tensor output("output", { rows, cols }, otype, true, false); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr scales = std::make_unique(blocks_num); + + fill_tensor_data(input, scales.get(), scales.get(), rowwise, colwise, rows, cols, + blocks_num_rowwise, blocks_num_colwise); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + InputType * data_ptr = rowwise + ? input.rowwise_cpu_dptr() + : input.columnwise_cpu_dptr(); + + compute_ref_x1(data_ptr, + ref_output.get(), + scales.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), true, atol, rtol); +} + +// Dequantize along single dimension (either row- or columnwise) +template +void performTest_quantize_then_dequantize(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType in_type = TypeInfo::dtype; + DType intermed_type = TypeInfo::dtype; + DType out_type = TypeInfo::dtype; + + std::unique_ptr input_cpu = std::make_unique(rows * cols); + std::unique_ptr quantized_cpu = std::make_unique(rows * cols); + std::unique_ptr output_cpu = std::make_unique(rows * cols); + + // input --> quantized --> output (dequantized) + // input == output + Tensor input("input", { rows, cols }, in_type); + Tensor quantized("quantized", { rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + // Output data are written to the rowwise ptr regardless of the scaling direction + Tensor output("output", { rows, cols }, out_type, true, false); + 
+ // fillCase(&input, InputsFillCase::minNorm_to_maxNorm); + fillCase(&input, InputsFillCase::uniform); + + const size_t copy_size = sizeof(InputType) * rows * cols; + cudaMemcpy(input_cpu.get(), input.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + nvte_quantize(input.data(), quantized.data(), 0); + cudaDeviceSynchronize(); + + const size_t copy_size_quantized = sizeof(IntermediateType) * rows * cols; + if (rowwise) { + cudaMemcpy(quantized_cpu.get(), quantized.rowwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost); + } + if (colwise) { + cudaMemcpy(quantized_cpu.get(), quantized.columnwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost); + } + + nvte_dequantize(quantized.data(), output.data(), 0); + cudaDeviceSynchronize(); + + cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(intermed_type); + compareResults("Quantize-Dequantize", input, output_cpu.get(), true, atol, rtol); +} + +// Dequantize along both dimensions (row- and columnwise) +template +void performTest_x2(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) +{ + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = 
round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t scales_stride_rowwise = blocks_X_rowwise; + const size_t scales_stride_colwise = blocks_X_colwise; + const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + Tensor input("input", { rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output("output", { rows, cols }, otype); + + std::unique_ptr ref_output_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_num_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_num_colwise); + + constexpr bool rowwise = true; + constexpr bool colwise = true; + fill_tensor_data(input, ref_scales_rowwise.get(), ref_scales_colwise.get(), + rowwise, colwise, rows, cols, blocks_num_rowwise, blocks_num_colwise); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x2(input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol); +} + +std::vector> tensor_dims = { + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {768, 1024}, + // {2048, 12288}, + // {65536, 128}, + // {16384, 1632}, + // {16384, 6144}, +}; 
+ +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + // {32, 32}, +}; + +} // namespace + +class DequantizeMXFP8TestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + bool>> {}; + +TEST_P(DequantizeMXFP8TestSuite, TestDequantizeMXFP8) +{ + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto tensor_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const bool quantize_then_dequantize = std::get<4>(GetParam()); + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + + // Skip tests for dequantization along both dimensions + if (rowwise && colwise) { + GTEST_SKIP(); + } + + // Skip cases with invalid alignment + if (rowwise && tensor_size.second % 32 != 0) { + GTEST_SKIP(); + } + if (colwise && tensor_size.first % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + if (quantize_then_dequantize) { + // Mind the order of the Output/Input template parameters + performTest_quantize_then_dequantize( + tensor_size.first, tensor_size.second, rowwise, colwise); + } else { + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1(tensor_size.first, tensor_size.second, + rowwise, colwise); + } else { + performTest_x2(tensor_size.first, tensor_size.second, + block_size.first, block_size.second); + } + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(tensor_dims), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + 
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(false)), + [](const testing::TestParamInfo& info) + { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + (std::get<4>(info.param) ? "QD" : "D"); + return name; + } +); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index e9f420e5b1..f07138caca 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -69,7 +69,7 @@ void performTest() { const size_t num_tensors = tensor_dims.size(); // Buffers for Transformer Engine implementation - std::vector input_list, output_c_list, output_t_list; + std::vector input_list, output_list; // Buffers for reference implementation std::vector> ref_input_list; @@ -81,25 +81,23 @@ void performTest() { for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_c_list.emplace_back(Tensor({ height, width }, otype)); - output_t_list.emplace_back(Tensor({ width, height }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), + { height, width }, otype, true, true)); auto& input = input_list.back(); - auto& output_c = output_c_list.back(); - auto& output_t = output_t_list.back(); + auto& output = output_list.back(); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); 
ref_input_list.emplace_back(height*width); ref_output_c_list.emplace_back(height*width); ref_output_t_list.emplace_back(width*height); - std::copy(input.cpu_dptr(), - input.cpu_dptr() + height * width, + std::copy(input.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr() + height * width, ref_input_list.back().begin()); - ref_scale_list[tensor_id] = output_c.scale(); + ref_scale_list[tensor_id] = output.scale(); ref_height_list[tensor_id] = height; ref_width_list[tensor_id] = width; } @@ -115,8 +113,7 @@ void performTest() { }; nvte_multi_cast_transpose(num_tensors, make_nvte_vector(input_list).data(), - make_nvte_vector(output_c_list).data(), - make_nvte_vector(output_t_list).data(), + make_nvte_vector(output_list).data(), 0); // Reference implementation @@ -136,23 +133,23 @@ void performTest() { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", - output_c_list[tensor_id].amax(), + output_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); compareResults("scale_inv", - output_c_list[tensor_id].scale_inv(), - 1.f / output_c_list[tensor_id].scale(), + output_list[tensor_id].rowwise_scale_inv(), + 1.f / output_list[tensor_id].scale(), atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", - output_c_list[tensor_id], + output_list[tensor_id], ref_output_c_list[tensor_id].data(), - atol, rtol); + true, atol, rtol); compareResults("output_t", - output_t_list[tensor_id], + output_list[tensor_id], ref_output_t_list[tensor_id].data(), - atol, rtol); + false, atol, rtol); } } diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu index 23c824e857..b8475fe561 100644 --- a/tests/cpp/operator/test_multi_padding.cu +++ b/tests/cpp/operator/test_multi_padding.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -84,8 +85,8 @@ void performTest() { const size_t height = tensor_dims[tensor_id].first; 
const size_t width = tensor_dims[tensor_id].second; const size_t padded_height = (height + align - 1) / align * align; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_list.emplace_back(Tensor({ padded_height, width }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), { padded_height, width }, otype)); auto& input = input_list.back(); auto& output = output_list.back(); @@ -95,8 +96,8 @@ void performTest() { ref_input_list.emplace_back(height*width); ref_output_list.emplace_back(padded_height*width); - std::copy(input.cpu_dptr(), - input.cpu_dptr() + height * width, + std::copy(input.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr() + height * width, ref_input_list.back().begin()); ref_height_list[tensor_id] = height; ref_width_list[tensor_id] = width; @@ -134,6 +135,7 @@ void performTest() { compareResults("output", output_list[tensor_id], ref_output_list[tensor_id].data(), + true, atol, rtol); } } diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 58152864eb..0004c2ce74 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -176,6 +175,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; return; } + + if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) { + GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!"; + } + using WeightType = InputType; DType itype = TypeInfo::dtype; DType wtype = TypeInfo::dtype; @@ -187,16 +191,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, return; } - Tensor input({ N, H }, itype); - Tensor z({ N, H }, otype); - Tensor gamma({ H 
}, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); - Tensor dz({ N, H }, wtype); - Tensor dx({ N, H }, itype); - Tensor dgamma({ H }, wtype); - Tensor dbeta({ H }, wtype); + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor dz("dz", { N, H }, wtype); + Tensor dx("dx", { N, H }, itype); + Tensor dgamma("dgamma", { H }, wtype); + Tensor dbeta("dbeta", { H }, wtype); Tensor workspace_fwd, workspace_bwd; fillUniform(&input); @@ -226,7 +230,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -236,7 +240,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -246,7 +250,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 
workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -255,7 +259,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), workspace_bwd.data(), @@ -272,23 +276,24 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, mu.to_cpu(); rsigma.to_cpu(); float ref_amax; - compute_ref_stats(norm_type, input.cpu_dptr(), ref_mu.get(), + compute_ref_stats(norm_type, input.rowwise_cpu_dptr(), ref_mu.get(), ref_rsigma.get(), N, H, epsilon); float ref_scale = isFp8Type(otype) ? 
z.scale() : 1.f; - compute_ref_output(norm_type, input.cpu_dptr(), - gamma.cpu_dptr(), - beta.cpu_dptr(), + compute_ref_output(norm_type, input.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + beta.rowwise_cpu_dptr(), ref_output.get(), - mu.cpu_dptr(), - rsigma.cpu_dptr(), + mu.rowwise_cpu_dptr(), + rsigma.rowwise_cpu_dptr(), N, H, &ref_amax, ref_scale, zero_centered_gamma, use_cudnn); - compute_ref_backward(norm_type, dz.cpu_dptr(), input.cpu_dptr(), - mu.cpu_dptr(), rsigma.cpu_dptr(), - gamma.cpu_dptr(), + compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), N, H, zero_centered_gamma, use_cudnn); @@ -301,25 +306,25 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); rtol_stats = 5e-5; - compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); + compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats); auto [atol, rtol] = getTolerances(otype); if (otype == DType::kFloat32) { atol = 5e-7; } - compareResults("output", z, ref_output.get(), atol, rtol); + compareResults("output", z, ref_output.get(), true, atol, rtol); double atol_bwd = 5e-4; double rtol_bwd = 5e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); - compareResults("dbeta", dbeta, 
ref_dbeta.get(), atol_bwd, rtol_bwd); + compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd); + compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd); + compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd); } std::vector> test_cases = { @@ -357,24 +362,24 @@ TEST_P(NormTestSuite, TestNorm) { } INSTANTIATE_TEST_SUITE_P( - OperatorTest, - NormTestSuite, - ::testing::Combine( - ::testing::Values(false), //TODO: enabling tests for cudnn backend - ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo& info) { + OperatorTest, + NormTestSuite, + ::testing::Combine( + ::testing::Values(true, false), + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { auto backend = std::get<0>(info.param) == false ? 
"Te" : "Cudnn"; -std::string name = - backend + - normToString.at(std::get<1>(info.param)) + "_" + - test::typeName(std::get<2>(info.param)) + "X" + - test::typeName(std::get<3>(info.param)) + "X" + - std::to_string(std::get<4>(info.param).first) + "X" + - std::to_string(std::get<4>(info.param).second) + "X" + - std::to_string(std::get<5>(info.param)); - return name; - }); + std::string name = + backend + + normToString.at(std::get<1>(info.param)) + "_" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + std::to_string(std::get<4>(info.param).first) + "X" + + std::to_string(std::get<4>(info.param).second) + "X" + + std::to_string(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu new file mode 100644 index 0000000000..d1bdb6203b --- /dev/null +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -0,0 +1,337 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +using fp8e8m0 = byte; + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RMSNorm"} +}; + +template +void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr, + size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){ + + const size_t block_size_Y = scaling_mode_x; // mind the mapping Y <-- x + const size_t block_size_X = scaling_mode_y; // and X <-- y + const size_t tile_size_Y = std::max(32lu, block_size_Y); + const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; + const size_t blocks_per_tile_X = tile_size_X / block_size_X; + const size_t blocks_per_row = (cols + block_size_X - 1) / block_size_X; + + #pragma omp parallel for proc_bind(spread) schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { + const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; + const size_t block_offset_Y = ii * block_size_Y; + const size_t i_min = tile_offset_Y + block_offset_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + + for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { + const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; + const size_t 
block_offset_X = jj * block_size_X; + const size_t j_min = tile_offset_X + block_offset_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + const size_t mx_scale_idx = block_idx_Y * blocks_per_row + block_idx_X; + + // TODO: padded SFs i.e. (4,128) + const float scale_inv = exp2f(static_cast(scale_ptr[mx_scale_idx]) - FP32_EXPONENT_BIAS); + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float elem = static_cast(input_ptr[idx]); + output_ptr[idx] = static_cast(elem * scale_inv); + } + } + } + } + } +} + +template +void dequantize_2x(Tensor& input, Tensor& output, bool is_training) +{ + input.to_cpu(); + auto scaling_mode = input.scaling_mode(); + assert(input.rowwise_shape().ndim == 2); + assert(input.columnwise_shape().ndim == 2); + + dequantize_1x_kernel(input.rowwise_cpu_dptr(), + input.rowwise_cpu_scale_inv_ptr(), + output.rowwise_cpu_dptr(), + input.rowwise_shape().data[0], input.rowwise_shape().data[1], + 1, 32); + if (is_training) + dequantize_1x_kernel(input.columnwise_cpu_dptr(), + input.columnwise_cpu_scale_inv_ptr(), + output.columnwise_cpu_dptr(), + input.columnwise_shape().data[0], input.columnwise_shape().data[1], + 32, 1); +} + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + + #pragma omp parallel for proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + compute_t m; + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +template +void compute_ref_output(NormType norm_type, + 
const InputType *data, const InputType *gamma, const InputType *beta, + const float *mu, const float *rsigma, + const size_t N, const size_t H, + OutputType* output, + const bool zero_centered_gamma){ + using compute_t = float; + + #pragma omp parallel for proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = static_cast(gamma[j]); + if (zero_centered_gamma) { + g += 1.0; + } + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = tmp; + } + } +} + +template +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) { + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using WeightType = InputType; + DType itype = TypeInfo::dtype; + DType wtype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor workspace; + + + fillUniform(&input); + fillUniform(&gamma); + fillUniform(&beta); + + // Forward kernel + float epsilon = 1e-5; + if (norm_type == NormType::LayerNorm){ + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace.data(), + 
prop.multiProcessorCount, zero_centered_gamma, + 0); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + } + + Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); + + dequantize_2x(z, dequantized_output, is_training); + + // Reference implementations + std::unique_ptr ref_mu = std::make_unique(N); + std::unique_ptr ref_rsigma = std::make_unique(N); + std::unique_ptr ref_output = std::make_unique(N * H); + + + compute_ref_stats(norm_type, input.rowwise_cpu_dptr(), ref_mu.get(), + ref_rsigma.get(), N, H, epsilon); + // use the GPU stats to tighten the tolerances + float *ref_mu_ptr, *ref_rsigma_ptr; + if (is_training){ + mu.to_cpu(); + rsigma.to_cpu(); + ref_mu_ptr = mu.rowwise_cpu_dptr(); + ref_rsigma_ptr = rsigma.rowwise_cpu_dptr(); + } else { + ref_mu_ptr = ref_mu.get(); + ref_rsigma_ptr = ref_rsigma.get(); + } + compute_ref_output(norm_type, input.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + beta.rowwise_cpu_dptr(), + ref_mu_ptr, + ref_rsigma_ptr, + N, H, + ref_output.get(), + zero_centered_gamma); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); + rtol_stats = 5e-5; + if (is_training){ + compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats); + } + + float atol, rtol; + if (otype == DType::kFloat8E5M2){ + atol = 1.25e-1; + rtol = 1.25e-1; + } else if (otype == DType::kFloat8E4M3){ + if (itype == DType::kBFloat16){ + atol = 7e-2; + rtol = 
7e-2; + } else { + atol = 6.25e-2; + rtol = 6.25e-2; + } + } + compareResults("output_rowwise", dequantized_output, ref_output.get(), true, atol, rtol, false); + if (is_training) + compareResults("output_colwise", dequantized_output, ref_output.get(), false, atol, rtol, false); +} + +std::vector> test_cases = { + {32, 32}, + {768, 2304}, + {2048, 12288}, +}; + +std::vector norms = { + NormType::LayerNorm, + NormType::RMSNorm +}; + +} // namespace + +class MxNormTestSuite : public ::testing::TestWithParam< std::tuple, + bool, bool>> {}; + +TEST_P(MxNormTestSuite, TestMxNorm) { + using namespace transformer_engine; + using namespace test; + + const NormType norm_type = std::get<0>(GetParam()); + const DType input_type = std::get<1>(GetParam()); + const DType output_type = std::get<2>(GetParam()); + const auto size = std::get<3>(GetParam()); + const bool zero_centered_gamma = std::get<4>(GetParam()); + const bool is_training = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxNormTestSuite, + ::testing::Combine( + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(true, false), + ::testing::Values(true, false)), + [](const testing::TestParamInfo& info) { + std::string name = normToString.at(std::get<0>(info.param)) + "_" + + test::typeName(std::get<1>(info.param)) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + std::to_string(std::get<3>(info.param).first) + "X" + + std::to_string(std::get<3>(info.param).second) + "X" + + std::to_string(std::get<4>(info.param)) + "out" + + 
std::to_string(int(std::get<5>(info.param)) + 1) + "x"; + return name; + }); diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu index 76f049360a..3c12cef865 100644 --- a/tests/cpp/operator/test_qdq.cu +++ b/tests/cpp/operator/test_qdq.cu @@ -58,18 +58,18 @@ void performTestQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); fillUniform(&input); setRandomScale(&output); - nvte_fp8_quantize(input.data(), output.data(), 0); + nvte_quantize(input.data(), output.data(), 0); float ref_amax; - compute_ref_q(input.cpu_dptr(), ref_output.get(), + compute_ref_q(input.rowwise_cpu_dptr(), ref_output.get(), N, &ref_amax, output.scale()); cudaDeviceSynchronize(); @@ -79,7 +79,7 @@ void performTestQ(const size_t N) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); auto [atol, rtol] = getTolerances(otype); - compareResults("output_q", output, ref_output.get(), atol, rtol); + compareResults("output_q", output, ref_output.get(), true, atol, rtol); } template @@ -89,24 +89,24 @@ void performTestDQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); fillUniform(&input); - nvte_fp8_dequantize(input.data(), output.data(), 0); + nvte_dequantize(input.data(), output.data(), 0); - compute_ref_dq(input.cpu_dptr(), ref_output.get(), - N, input.scale_inv()); + compute_ref_dq(input.rowwise_cpu_dptr(), ref_output.get(), + N, input.rowwise_scale_inv()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << 
cudaGetErrorString(err); auto [atol, rtol] = getTolerances(otype); - compareResults("output_dq", output, ref_output.get(), atol, rtol); + compareResults("output_dq", output, ref_output.get(), true, atol, rtol); } std::vector qdq_test_cases = {2048* 12288, diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu new file mode 100644 index 0000000000..f6e0da057a --- /dev/null +++ b/tests/cpp/operator/test_swizzle.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; + +constexpr int MAT_TILE_DIM_M = 128; +constexpr int MAT_TILE_DIM_K = 128; + +template +void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, + const size_t M, const size_t K) { + + constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; + constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; + constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; + + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + + int tile_id_m = m / SF_TILE_DIM_M; + int tile_id_k = k / SF_TILE_DIM_K; + int m_in_tile = m % SF_TILE_DIM_M; + int k_in_tile = k % SF_TILE_DIM_K; + + int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; + int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; + + int tile_output_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; + int out_index = tile_output_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; + if constexpr(row_scaling) + h_output[out_index] = h_input[k + m * K]; + else + h_output[out_index] = 
h_input[k * M + m]; + } + } +} + +void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + std::vector scaling_mode = {SF_MODE_X, SF_MODE_Y, 0}; + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + std::unique_ptr ref_output = std::make_unique(scale_shape[0] * scale_shape[1]); + + nvte_swizzle_scaling_factors(input.data(), output.data(), 0); + + if (rowwise) + compute_ref_swizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0], scale_shape[1]); + else + compute_ref_swizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[1], scale_shape[0]); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + output.to_cpu(); + if (rowwise) { + compareResults("output_swizzle", output.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("output_swizzle", output.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } +} + +class SwizzleTestSuite : public 
::testing::TestWithParam, std::pair, bool>> {}; + + +TEST_P(SwizzleTestSuite, TestSwizzle) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestSwizzle1D(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); +} + +namespace { + +std::vector> num_tiles = { + {1, 1}, + {1, 132}, + {132, 1}, + {65, 256}, + {65, 257}, + {65, 258}, + {65, 259}, +}; + +std::vector> scaling_mode = { + {true, false}, + {false, true} +}; + +std::vector transa = {true, false}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_transpose.cu b/tests/cpp/operator/test_transpose.cu index 0852ddf7c3..00dd241c92 100644 --- a/tests/cpp/operator/test_transpose.cu +++ b/tests/cpp/operator/test_transpose.cu @@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) { DType dtype = TypeInfo::dtype; - Tensor input({ N, H }, dtype); - Tensor output({ H, N }, dtype); + Tensor input("input", { N, H }, dtype); + Tensor output("output", { H, N }, dtype); std::unique_ptr ref_output = std::make_unique(N * H); @@ -46,13 +46,13 @@ void performTest(const size_t N, const size_t H) { nvte_transpose(input.data(), output.data(), 0); - compute_ref(input.cpu_dptr(), ref_output.get(), N, H); + compute_ref(input.rowwise_cpu_dptr(), 
ref_output.get(), N, H); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); auto [atol, rtol] = getTolerances(dtype); - compareResults("output", output, ref_output.get(), atol, rtol); + compareResults("output", output, ref_output.get(), true, atol, rtol); } std::vector> test_cases = {{2048, 12288}, diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 84cc11673b..ec4a9bdbb7 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -10,14 +10,24 @@ #include #include #include +#include +#include +#include #include +#include #include #include "util/logging.h" namespace test { +size_t create_seed_from_tensor_name(const std::string& tensor_name) { + auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) + + "/" + tensor_name; + return std::hash{}(full_name); +} + std::vector all_fp_types = {DType::kFloat32, DType::kFloat16, DType::kBFloat16, @@ -50,102 +60,379 @@ const std::string &typeName(DType type) { {DType::kFloat16, "float16"}, {DType::kBFloat16, "bfloat16"}, {DType::kFloat8E4M3, "float8e4m3"}, - {DType::kFloat8E5M2, "float8e5m2"}}; + {DType::kFloat8E5M2, "float8e5m2"}, + {DType::kFloat8E8M0, "float8e8m0"}}; return name_map.at(type); } -size_t product(const NVTEShape &shape) { +const std::string& caseName(InputsFillCase type) { + static const std::unordered_map name_map = { + {InputsFillCase::uniform, "uniform"}, + {InputsFillCase::zeros, "zeros"}, + {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"}, + {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"}, + {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}}; + return name_map.at(type); +} + +size_t product(const NVTEShape &shape, size_t begin, size_t end) { size_t ret = 1; - for (size_t i = 0; i < shape.ndim; ++i) { + NVTE_CHECK(end <= shape.ndim); + for (size_t i = begin; i < end; ++i) { ret *= shape.data[i]; } return ret; } +size_t product(const NVTEShape &shape) { + return 
product(shape, 0, shape.ndim); +} +size_t product(const std::vector shape, size_t begin, size_t end) { + size_t ret = 1; + NVTE_CHECK(end <= shape.size()); + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} -Tensor::Tensor(const NVTEShape &shape, const DType type) { - size_t s = typeToSize(type); - size_t total_size = product(shape) * s; - void *dptr = nullptr; - cpu_data_ = nullptr; - amax_cpu_data_ = nullptr; - scale_cpu_data_ = nullptr; - scale_inv_cpu_data_ = nullptr; - float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr; - if (total_size != 0) { - cudaMalloc((void**)&dptr, total_size); // NOLINT(*) - cudaMemset(dptr, 0, total_size); - cpu_data_ = std::make_unique(total_size); - for (size_t i = 0; i < total_size; ++i) { - cpu_data_[i] = 0; - } +size_t product(const std::vector& shape) { + return product(shape, 0, shape.size()); +} + +size_t DIVUP(const size_t &x, const size_t &y){ + return (((x) + ((y)-1)) / (y)); +} + +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +struct scale_inv_meta { + std::vector shape; + DType type; + size_t type_size; +}; + +NVTEShape convertShape(const std::vector& shape) { + return {shape.data(), shape.size()}; +} + +std::pair get_scales(const NVTEShape& shape, + const NVTEScalingMode scaling_mode) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + scale_inv_meta ret; + ret.shape = {1}; + ret.type = DType::kFloat32; + ret.type_size = sizeof(float); + return {ret, ret}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + auto block_alignment = std::vector{128ul,4ul}; + { + auto alignment = block_alignment[0]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, + static_cast(1)), + 
alignment) * alignment; + alignment = block_alignment[1]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, + static_cast(32)), + alignment) * alignment; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto alignment = block_alignment[1]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, + static_cast(32)), + alignment) * alignment; + alignment = block_alignment[0]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, + static_cast(1)), + alignment) * alignment; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; } - if (isFp8Type(type)) { + ret_rowwise.type = DType::kFloat8E8M0; + ret_colwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size = sizeof(uint8_t); + ret_colwise.type_size = sizeof(uint8_t); + + return {ret_rowwise, ret_colwise}; + } + + NVTE_ERROR("Invalid scaling mode!"); +} + +Tensor::Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise, const bool columnwise, + const NVTEScalingMode &scaling_mode) { + name_ = name; + const size_t seed = create_seed_from_tensor_name(name); + gen_.seed(seed); + rowwise_ = rowwise; + columnwise_ = columnwise; + size_t s = typeToSize(type); + size_t total_size = product(shape) * s; + void *dptr_rowwise = nullptr; + void *dptr_columnwise = nullptr; + cpu_data_rowwise_ = nullptr; + cpu_data_columnwise_ = nullptr; + amax_cpu_data_ = nullptr; + scale_cpu_data_ = nullptr; + rowwise_scale_inv_cpu_data_ = nullptr; + columnwise_scale_inv_cpu_data_ = nullptr; + float *amax = nullptr, *scale = nullptr; + float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr; + if (columnwise) { + NVTE_CHECK(shape.ndim >= 2); + } + std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), + shape.data[shape.ndim - 1]}; + NVTEShape normalized_shape = convertShape(normalized_shape_v); + + std::vector columnwise_shape_vec; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + // Transpose when tensor scaling + columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); + for (size_t i = 0; 
i < shape.ndim - 1; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } else { + // Same shape for MX + for (size_t i = 0; i < shape.ndim; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } + const NVTEShape columnwise_shape{columnwise_shape_vec.data(), columnwise_shape_vec.size()}; + + tensor_ = TensorWrapper(scaling_mode); + + if (total_size != 0) { + if (rowwise) { + cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) + cudaMemset(dptr_rowwise, 0, total_size); + cpu_data_rowwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_rowwise_.get(), total_size, 0); + } + if (columnwise) { + cudaMalloc((void**)&dptr_columnwise, total_size); // NOLINT(*) + cudaMemset(dptr_columnwise, 0, total_size); + cpu_data_columnwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_columnwise_.get(), total_size, 0); + } + } + tensor_.set_rowwise_data(dptr_rowwise, type, shape); + tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); + + if (isFp8Type(type)) { + if (is_tensor_scaling(scaling_mode)) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) cudaMemset(scale, 0, sizeof(float)); - cudaMalloc((void**)&scale_inv, sizeof(float)); // NOLINT(*) - cudaMemset(scale_inv, 0, sizeof(float)); - amax_cpu_data_ = std::make_shared(); - *amax_cpu_data_ = 0; - scale_cpu_data_ = std::make_shared(); - *scale_cpu_data_ = 0; - scale_inv_cpu_data_ = std::make_shared(); - *scale_inv_cpu_data_ = 0; + amax_cpu_data_ = std::make_shared(0); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + cudaMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) + if (rowwise) { + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + 
std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + if (columnwise) { + tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + } else { + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, + tensor_.scaling_mode()); + auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_shape = rowwise_scale_meta.shape; + auto columnwise_scale_shape = colwise_scale_meta.shape; + if (rowwise) { + cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); + rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + } + if (columnwise) { + cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) + cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); + columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + } } - tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv); + } } void Tensor::to_cpu() const { const NVTEShape s = tensor_.shape(); const size_t size = product(s) * typeToSize(tensor_.dtype()); - cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost); + if (rowwise_) { + cudaMemcpy(cpu_data_rowwise_.get(), + tensor_.get_rowwise_data().data_ptr, + size, + cudaMemcpyDeviceToHost); + } + if (columnwise_) { + 
cudaMemcpy(cpu_data_columnwise_.get(), + tensor_.get_columnwise_data().data_ptr, + size, + cudaMemcpyDeviceToHost); + } if (isFp8Type(dtype())) { - cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float), - cudaMemcpyDeviceToHost); - cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float), - cudaMemcpyDeviceToHost); - cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float), - cudaMemcpyDeviceToHost); + if (is_tensor_scaling(tensor_.scaling_mode())) { + cudaMemcpy(amax_cpu_data_.get(), + tensor_.amax(), + sizeof(float), + cudaMemcpyDeviceToHost); + cudaMemcpy(scale_cpu_data_.get(), + tensor_.scale(), + sizeof(float), + cudaMemcpyDeviceToHost); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), + tensor_.get_rowwise_scale_inv().data_ptr, + scale_size, + cudaMemcpyDeviceToHost); + } + if (columnwise_) { + auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), + tensor_.get_columnwise_scale_inv().data_ptr, + scale_size, + cudaMemcpyDeviceToHost); + } } } void Tensor::from_cpu() const { const NVTEShape s = tensor_.shape(); const size_t size = product(s) * typeToSize(tensor_.dtype()); - cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice); + if (rowwise_) { + cudaMemcpy(tensor_.get_rowwise_data().data_ptr, + cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); + } + if (columnwise_) { + cudaMemcpy(tensor_.get_columnwise_data().data_ptr, + cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); + } if (isFp8Type(dtype())) { - cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); - cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); - cudaMemcpy(tensor_.scale_inv(), 
scale_inv_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + if (is_tensor_scaling(tensor_.scaling_mode())) { + cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, + rowwise_scale_inv_cpu_data_.get(), scale_size, + cudaMemcpyHostToDevice); + } + if (columnwise_) { + auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, + columnwise_scale_inv_cpu_data_.get(), scale_size, + cudaMemcpyHostToDevice); + } } } void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - *scale_cpu_data_ = scale; - from_cpu(); + if (is_tensor_scaling(tensor_.scaling_mode())) { + *scale_cpu_data_ = scale; + from_cpu(); + } } } void Tensor::set_scale_inv(float scale_inv) { if (isFp8Type(dtype())) { - NVTE_CHECK(scale_inv_cpu_data_); - *scale_inv_cpu_data_ = scale_inv; + if (rowwise_) { + NVTE_CHECK(rowwise_scale_inv_cpu_data_); + } + if (columnwise_) { + NVTE_CHECK(columnwise_scale_inv_cpu_data_); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + if (rowwise_) { + auto num_scales = product(rowwise_scale_meta.shape); + if (num_scales == 1){ + rowwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else{ + std::uniform_int_distribution dis(0, 127); + auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++){ + scale_inv_ptr[i] = dis(gen_); + } + } + } + if (columnwise_) { + auto num_scales = product(colwise_scale_meta.shape); + if (num_scales == 1){ + columnwise_cpu_scale_inv_ptr()[0] = scale_inv; + 
} else{ + std::uniform_int_distribution dis(0, 127); + auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++){ + scale_inv_ptr[i] = dis(gen_); + } + } + } from_cpu(); } } void Tensor::shareFP8Meta(const Tensor &other) { if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { - tensor_ = TensorWrapper(dptr(), shape(), dtype(), - other.tensor_.amax(), - other.tensor_.scale(), - other.tensor_.scale_inv()); + auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); + auto my_rowwise_data = tensor_.get_rowwise_data(); + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, + static_cast(my_rowwise_data.dtype), + my_rowwise_data.shape); + auto my_columnwise_data = tensor_.get_columnwise_data(); + new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, + static_cast(my_columnwise_data.dtype), + my_columnwise_data.shape); + auto other_amax = other.tensor_.get_amax(); + new_tensor.set_amax(other_amax.data_ptr, + static_cast(other_amax.dtype), + other_amax.shape); + auto other_scale = other.tensor_.get_scale(); + new_tensor.set_scale(other_scale.data_ptr, + static_cast(other_scale.dtype), + other_scale.shape); + auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); + new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, + static_cast(other_row_scale_inv.dtype), + other_row_scale_inv.shape); + auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv(); + new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr, + static_cast(other_col_scale_inv.dtype), + other_col_scale_inv.shape); + tensor_ = std::move(new_tensor); to_cpu(); } } @@ -177,12 +464,14 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { return ret; } -void compareResults(const std::string &name, const Tensor &test, const void *ref, - double atol, double rtol) { - test.to_cpu(); - const size_t N = product(test.shape()); +void compareResults_sequential(const std::string &name, const Tensor &test, + const void *ref, 
const bool rowwise, + double atol, double rtol, bool if_on_gpus) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, - const T *test_data = test.cpu_dptr(); + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); for (size_t i = 0; i < N; ++i) { double t = static_cast(test_data[i]); @@ -200,14 +489,84 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref const double cast_mean_m = static_cast(static_cast(mean_m)); assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); } - ASSERT_FALSE(assertion) << "Error in tensor " << name << std::endl - << "Mismatch at place " << to_string(unravel(i, test.shape())) + std::string direction = rowwise ? "rowwise" : "columnwise"; + ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) << " (" << std::to_string(i) << "): " << t << " vs " << r; + } + ); +} + +template +static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, + const size_t N, const double atol, const double rtol) { + int first_mismatch_idx = N; + + bool is_mismatch_found = false; + #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ + reduction(min: first_mismatch_idx) proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + if (is_mismatch_found) { // early escape of the omp thread + continue; + } + + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion && i < first_mismatch_idx) { + first_mismatch_idx = i; + is_mismatch_found = true; + } + } + return first_mismatch_idx; +} + +void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, + const T *test_data = rowwise ? 
test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); + const T *ref_data = reinterpret_cast(ref); + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); + if (i != N) { + const double t = static_cast(test_data[i]); + const double r = static_cast(ref_data[i]); + std::string direction = rowwise ? "rowwise" : "columnwise"; + ASSERT_FALSE(true) << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; } ); } +void compareResults(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus) { + constexpr bool sequential = false; + if constexpr (sequential) { + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); + } else { + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); + } +} + void compareResults(const std::string &name, const float test, const float ref, double atol, double rtol) { double t = static_cast(test); @@ -218,6 +577,51 @@ void compareResults(const std::string &name, const float test, const float ref, } + +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol) { + size_t max_mismatches = std::ceil(N * mismatch_rate_tol); + size_t n_mismatches = 0; + std::vector mismatch_indices; + for (int i = 0; i < N; i++){ + bool mismatch = test[i] != ref[i]; + if (mismatch){ + n_mismatches++; + mismatch_indices.push_back(i); + } + if (n_mismatches > max_mismatches){ + std::cout << "Error in " << name << std::endl; + for (auto &index : mismatch_indices) + std::cout << "Mismatch at (" << index << "):" << static_cast(test[i]) << " vs " + << static_cast(ref[i]) << std::endl; + GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol."; + } + } +} + +void compare_e8m0_scaling_factors(const 
std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride) +{ + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int idx = i * stride + j; + ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl + << "Mismatch: " << static_cast(test[idx]) << " vs " + << static_cast(ref[idx]) << " at index " << idx; + } + } +} + +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t N) +{ + for (int i = 0; i < N; i++) { + ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl + << "Mismatch: " << static_cast(test[i]) << " vs " + << static_cast(ref[i]) << " at index " << i; + } +} + std::pair getTolerances(const DType type) { switch(type) { case DType::kFloat32: @@ -228,6 +632,7 @@ std::pair getTolerances(const DType type) { return {1e-5, 1e-2}; case DType::kFloat8E4M3: case DType::kFloat8E5M2: + case DType::kFloat8E8M0: return {1e-2, 1e-2}; default: NVTE_CHECK("Invalid type!"); @@ -235,29 +640,158 @@ std::pair getTolerances(const DType type) { return {0, 0}; } +template +void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { + #pragma omp parallel proc_bind(spread) + { + std::mt19937 gen_local = *gen; + gen_local.discard(omp_get_thread_num() * 599); + std::uniform_real_distribution<> dis(-2.0, 1.0); + #pragma omp for schedule(static) + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(dis(gen_local)); + } + } + gen->discard(size); +} + void fillUniform(Tensor *t) { - const size_t size = product(t->shape()); - static std::mt19937 gen(12345); + if (t->rowwise()) { + const size_t size = product(t->rowwise_shape()); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { + T *data = t->rowwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); + } + ); + } else { + const size_t size = product(t->columnwise_shape()); + 
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { + T *data = t->columnwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); + } + ); + } std::uniform_real_distribution<> dis(-2.0, 1.0); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { - T *data = t->cpu_dptr(); + t->set_scale_inv(dis(t->gen())); + t->from_cpu(); +} + +template +void fillCase_special(Tensor *t) { + const size_t size = product(t->rowwise_shape()); + const size_t rows = t->rowwise_shape().data[0]; + const size_t cols = t->rowwise_shape().data[1]; + + if constexpr (Case == InputsFillCase::zeros) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { + InputType *data = t->rowwise_cpu_dptr(); for (size_t i = 0; i < size; ++i) { - data[i] = T(dis(gen)); + data[i] = static_cast(0); } - }); - t->set_scale_inv(dis(gen)); + }); + } else { + double minAbs = -2.0; + double maxAbs = 1.0; + if constexpr (Case != InputsFillCase::uniform) { + minAbs = Quantized_Limits::ranges[Case]; + maxAbs = Quantized_Limits::ranges[Case + 1]; + } + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { + InputType *data = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const bool is_negative = (dis_sign(t->gen()) < 0.0); + double val = dis(t->gen()); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } + } + }); + } + t->set_scale_inv(1.0); t->from_cpu(); } +template +void fillCase(Tensor *t, const InputsFillCase fill_case) { + switch (fill_case) { + case InputsFillCase::uniform: + fillCase_special(t); break; + case InputsFillCase::zeros: + fillCase_special(t); break; + case InputsFillCase::zero_to_minNorm: + fillCase_special(t); break; + case InputsFillCase::minNorm_to_maxNorm: + fillCase_special(t); break; + case InputsFillCase::maxNorm_to_inf: + 
fillCase_special(t); break; + } +} + +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); + void setRandomScale(Tensor *t) { - static std::mt19937 gen(12345); std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale = dis(gen); + const float scale = dis(t->gen()); t->set_scale(scale); } +void setRandomScaleInv(Tensor *t) { + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float scale_inv = dis(t->gen()); + t->set_scale_inv(scale_inv); +} + bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; +} + +int32_t getDeviceComputeCapability() +{ + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; +} + +size_t first_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + if (shape.size() == 1) return 1; + return product(shape, 0, shape.size() - 1); +} + +size_t last_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + return shape[shape.size() - 1]; +} + +std::array get_scale_tensor_dims(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) { + const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + + const size_t alignment_Y = is_rowwise + ? scale_tensor_alignment_Y_rowwise + : scale_tensor_alignment_Y_colwise; + const size_t alignment_X = is_rowwise + ? 
scale_tensor_alignment_X_rowwise + : scale_tensor_alignment_X_colwise; + + const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols); + + const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y); + const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X); + return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 4598a7b021..dc515ccb8e 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -6,9 +6,10 @@ #pragma once -#include #include #include +#include +#include #include #include @@ -52,6 +53,7 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using fp8e8m0 = uint8_t; template struct TypeInfo{ @@ -62,7 +64,8 @@ struct TypeInfo{ fp16, bf16, fp8e4m3, - fp8e5m2>; + fp8e5m2, + fp8e8m0>; template struct Helper { @@ -94,10 +97,19 @@ struct TypeInfo{ class Tensor { public: - Tensor(const NVTEShape &shape, const DType type); - - Tensor(const std::vector &shape, const DType type) : - Tensor(NVTEShape{shape.data(), shape.size()}, type) {} + Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + + Tensor(const std::string& name, + const std::vector &shape, + const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} @@ -108,30 +120,82 @@ class Tensor { Tensor& operator=(Tensor &&other) = default; ~Tensor() { - if (tensor_.dptr() != nullptr) { - cudaFree(tensor_.dptr()); + void *data_ptr = tensor_.dptr(); + void *scale_inv = 
tensor_.scale_inv(); + void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr; + void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr; + if (columnwise_data_ptr == data_ptr) { + columnwise_data_ptr = nullptr; + } + if (columnwise_scale_inv == scale_inv) { + columnwise_scale_inv = nullptr; + } + if (data_ptr != nullptr) { + cudaFree(data_ptr); + } + if (scale_inv != nullptr) { + cudaFree(scale_inv); + } + if (columnwise_data_ptr != nullptr){ + cudaFree(columnwise_data_ptr); + } + if (columnwise_scale_inv != nullptr){ + cudaFree(columnwise_scale_inv); } } + NVTETensor data() const noexcept { return tensor_.data(); } - const NVTEShape shape() const noexcept { - return tensor_.shape(); + NVTEShape rowwise_shape() const noexcept { + return tensor_.get_rowwise_data().shape; + } + + NVTEShape columnwise_shape() const noexcept { + return tensor_.get_columnwise_data().shape; + } + + NVTEShape rowwise_scale_inv_shape() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().shape; + } + + NVTEShape columnwise_scale_inv_shape() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().shape; + } + + NVTEScalingMode scaling_mode() const noexcept { + return tensor_.scaling_mode(); } DType dtype() const noexcept { return tensor_.dtype(); } - void *dptr() const noexcept { - return tensor_.dptr(); + void *rowwise_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_data().data_ptr; + } + + void *columnwise_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_data().data_ptr; + } + + template + T *rowwise_cpu_dptr() const { + NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return reinterpret_cast(cpu_data_rowwise_.get()); } template - T 
*cpu_dptr() const { + T *columnwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); - return reinterpret_cast(cpu_data_.get()); + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return reinterpret_cast(cpu_data_columnwise_.get()); } float amax() const { @@ -145,6 +209,7 @@ class Tensor { float scale() const { if(scale_cpu_data_) { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -152,52 +217,246 @@ class Tensor { } } - float scale_inv() const { - if(scale_inv_cpu_data_) { - to_cpu(); - return *scale_inv_cpu_data_; + template + T *rowwise_cpu_scale_inv_ptr(){ + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(rowwise_scale_inv_cpu_data_.get()); + } + + template + T *columnwise_cpu_scale_inv_ptr(){ + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); + } + + float rowwise_scale_inv(){ + if(rowwise_scale_inv_cpu_data_) { + float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; + return scale_inv; } else { return 1; } } + bool rowwise() const { + return rowwise_; + } + + bool columnwise() const { + return columnwise_; + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); void set_scale_inv(float scale_inv); void shareFP8Meta(const Tensor &other); + std::mt19937& gen() { return gen_; } + private: TensorWrapper tensor_; - std::unique_ptr cpu_data_; + std::unique_ptr cpu_data_rowwise_; + std::unique_ptr cpu_data_columnwise_; std::shared_ptr amax_cpu_data_; std::shared_ptr scale_cpu_data_; 
- std::shared_ptr scale_inv_cpu_data_; + std::unique_ptr rowwise_scale_inv_cpu_data_; + std::unique_ptr columnwise_scale_inv_cpu_data_; + bool rowwise_; + bool columnwise_; + std::string name_; + std::mt19937 gen_; +}; + +constexpr uint32_t FP32_EXPONENT_BIAS = 127; +constexpr uint32_t FP32_MANTISSA_BITS = 23; + +// [128,4] rowwise and [4,128] colwise alignment requirement +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + +inline size_t divide_round_up(const size_t N, const size_t M) { + return (N - 1 + M) / M; +} + +inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) { + return divide_round_up(N, M) * M; +} + +template +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0; + static constexpr double maxSubnorm = 1.0; + static constexpr double minNorm = 1.0; + static constexpr double maxNorm = 1.0; + static constexpr double artifInf = 1.0; + static constexpr int maxBiasedExponent = 1; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 9); // std::pow(2.0, -9.0); + static constexpr double maxSubnorm = 0.875 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double maxNorm = 448.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 8; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 16); // std::pow(2.0, -16.0); + static constexpr double maxSubnorm = 0.75 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double 
minNorm = 1.0 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double maxNorm = 57344.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 15; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = std::numeric_limits::denorm_min(); // std::pow(2.0, -149.0); + static constexpr double maxSubnorm = std::numeric_limits::min() + - std::numeric_limits::denorm_min(); // minNormalized - minDenormalized + static constexpr double minNorm = std::numeric_limits::min(); // std::pow(2.0, -126.0); + static constexpr double maxNorm = std::numeric_limits::max(); // (1 - pow(2, -24)) * pow(2, 128) + static constexpr double artifInf = std::numeric_limits::infinity(); + static constexpr int maxBiasedExponentAsFP32 = 255; + static constexpr int maxUnbiasedExponentAsFP32 = 128; +}; + +template +struct Quantized_Limits { + static constexpr double ranges[] = { + 0.0, + Numeric_Traits::minNorm, + Numeric_Traits::maxNorm, + Numeric_Traits::artifInf + }; + static constexpr inline fp32 max() { return static_cast(Numeric_Traits::maxNorm); } + static constexpr inline fp32 max_reciprocal() { return static_cast(1.0 / max()); } + static constexpr inline fp32 emax() { return static_cast(Numeric_Traits::maxExpNorm); } + static constexpr inline fp32 emax_reciprocal() { return static_cast(1.0 / emax()); } + static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits::maxBiasedExponentAsFP32; } + static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits::maxUnbiasedExponentAsFP32; } +}; + +// Input data filling cases +// Considering normal and subnormal magnitudes of E4M3 and E5M2 formats +// with nearest to even rounding per OFP8 specification +enum InputsFillCase { + zero_to_minNorm = 0, // [0, 
min_normal) + minNorm_to_maxNorm = 1, // [min_normal, max_normal) + maxNorm_to_inf = 2, // [max_normal, inf) + zeros = 3, // {0} + uniform = 4, // std::uniform_real_distribution<> dis(-2.0, 1.0) }; +inline fp8e8m0 float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (std::isnan(val)) { + return 0xFF; + } + if (std::isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +} + +inline float exp2f_rcp(fp8e8m0 biased_exp) { + return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + +inline float identity(const float x) { return x; } +inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); } +inline float dgelu(const float x) { + const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x)); + return 0.5f * x * ((1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + + 0.5f * (1 + tanh_out); +} +inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); } +inline float dsigmoid(const float x) { return sigmoid(x) * (1 - sigmoid(x)); } +inline float qgelu(const float x) { return x * sigmoid(1.702f * x); } +inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); } +inline float relu(const float x) { return fmaxf(0, x); } +inline float drelu(const float x) { return x > 0 ? 1 : 0; } +inline float silu(const float x) { return x * sigmoid(x); } +inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } +inline float srelu(const float x) { return x > 0 ? 
x * x : 0; } +inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } + size_t typeToSize(DType type); size_t product(const NVTEShape &shape); +size_t product(const std::vector &shape); + +size_t first_dimension(const std::vector &shape); +size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, - double atol = 1e-5, double rtol = 1e-8); + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol = 0.); +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride); +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t N); + +std::array get_scale_tensor_dims(const size_t rows, const size_t cols, + const size_t block_size_rows, const size_t block_size_cols); std::pair getTolerances(const DType type); void fillUniform(Tensor *t); + +template +void fillCase(Tensor *t, const InputsFillCase fill_case); + void setRandomScale(Tensor *t); +void setRandomScaleInv(Tensor *t); constexpr int THREADS_PER_WARP = 32; const std::string &typeName(DType type); +const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +int32_t getDeviceComputeCapability(); +constexpr int32_t blackwellComputeCapability = 100; + } // namespace test #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ @@ -254,3 +513,47 @@ bool isFp8Type(DType type); default: \ NVTE_ERROR("Invalid type."); \ } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E4M3: \ + { \ + using type = fp8e4m3; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E5M2: \ + { \ + using type = fp8e5m2; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: \ + { \ + using type = float; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat16: \ + { \ + using type = fp16; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kBFloat16: \ + { \ + using type = bf16; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index ffa05f0d66..7540687089 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -8,8 +8,9 @@ add_executable(test_util ../test_common.cu) -target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) -target_compile_options(test_util PRIVATE -O2) +find_package(OpenMP REQUIRED) +target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) +target_compile_options(test_util PRIVATE -O2 -fopenmp) include(GoogleTest) -gtest_discover_tests(test_util) +gtest_discover_tests(test_util DISCOVERY_TIMEOUT 600) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 920f9dc62e..d1558710c7 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -27,9 +27,6 @@ def enable_fused_attn_after_hopper(): """ if get_device_compute_capability(0) >= 90: os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" yield if "NVTE_FUSED_ATTN" in os.environ: del os.environ["NVTE_FUSED_ATTN"] - if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ: - del 
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index e6ad8ce20c..a67335236d 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -4,14 +4,19 @@ """Test transformer_engine.jax.flax.TransformerLayer""" import os from functools import partial -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional import flax import jax import jax.numpy as jnp import pytest -from utils import assert_allclose, assert_tree_like_allclose, sync_params_values +from utils import ( + assert_allclose, + assert_tree_like_allclose, + dtype_tols, + sync_params_values, +) from utils import DecoderLayer as RefDecoderLayer from utils import EncoderLayer as RefEncoderLayer @@ -250,7 +255,13 @@ def _sync_params(self, ref, target): target = sync_params_values(target, ref, self.transformations) return ref, target - def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + def test_forward( + self, + data_shape: Tuple[int], + dtype: jnp.dtype, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> None: """Test only the forward""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) @@ -264,9 +275,16 @@ def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer) test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + tols = dtype_tols(dtype, rtol=rtol, atol=atol) + assert_allclose(ref_out, test_out, **tols) - def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + def test_backward( + self, + data_shape: Tuple[int], + dtype: jnp.dtype, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> None: """Test forward and backward through value_and_grad()""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) @@ -302,11 +320,12 @@ def 
test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): inputs, test_masks, test_params, test_others, test_layer ) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) + tols = dtype_tols(dtype, rtol=rtol, atol=atol) + assert_allclose(ref_out, test_out, **tols) + assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols) _, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads) - assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, rtol=rtol, atol=atol) + assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, **tols) class EncoderRunner(BaseRunner): @@ -418,12 +437,12 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5) + self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) + self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_format", FP8_FORMATS) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 9cb02bc555..554def2c3f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1387,18 +1387,26 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): def dtype_tols( dtype: Union[DType, TEDType, np.dtype], reference_value: float = 1.0, + rtol: Optional[float] = None, + atol: Optional[float] = None, ) -> Dict[str, float]: """Expected numerical tolerance for a data type. Args: dtype: data type. reference_value: reference value (default: 1). 
+ rtol: override for relative tolerance estimate + atol: override for absolute tolerance estimate Returns: Dictionary with "rtol" and "atol" as keys """ + # Return immediately if tolerances are fully specified + if rtol is not None and atol is not None: + return {"rtol": rtol, "atol": atol} + # Convert to JAX dtype if needed if isinstance(dtype, TEDType): dtype = { @@ -1416,7 +1424,11 @@ def dtype_tols( # Expect bit-wise accuracy for integer dtypes if not jnp.issubdtype(dtype, jnp.floating): - return dict(rtol=0, atol=0) + if rtol is None: + rtol = 0.0 + if atol is None: + atol = 0.0 + return {"rtol": rtol, "atol": atol} # Estimate floating-point error finfo = jnp.finfo(dtype) @@ -1429,10 +1441,11 @@ def dtype_tols( spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min) ulp = max(spacing_high.item(), spacing_low.item()) - return dict( - rtol=eps_relaxed, - atol=max(ulp, eps_relaxed), - ) + if rtol is None: + rtol = eps_relaxed + if atol is None: + atol = max(ulp, eps_relaxed) + return {"rtol": rtol, "atol": atol} def sync_params_values(dst, src, transformations, sep="/"): diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py deleted file mode 100644 index f262f1a1d4..0000000000 --- a/tests/paddle/dist_launcher.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Helper functions to launch distributed tests""" - -import copy -import os -from pathlib import Path -import subprocess -import time -import unittest - -try: - from paddle.base import core -except ImportError: - from paddle.fluid import core -from paddle.distributed.utils.launch_utils import ( - TrainerProc, - find_free_ports, - get_cluster, - watch_local_trainers, -) - -__all__ = ["TestDistributed"] - - -def get_cluster_from_args(selected_gpus): - """Get node information from selected GPUs""" - cluster_node_ips = "127.0.0.1" - node_ip = "127.0.0.1" - - node_ips = [x.strip() for x in cluster_node_ips.split(",")] - - node_ips.index(node_ip) - - free_ports = None - - free_ports = find_free_ports(len(selected_gpus)) - if free_ports is not None: - free_ports = list(free_ports) - - trainer_endpoints = [] - for ip in node_ips: - trainer_endpoints.append([f"{ip}:{port}" for port in free_ports]) - return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) - - -def get_gpus(selected_gpus): - """Get selected GPU string""" - selected_gpus = [x.strip() for x in selected_gpus.split(",")] - return selected_gpus - - -def start_local_trainers( - cluster, - pod, - training_script, - training_script_args, - allocator_strategy="auto_growth", -): - """Launch trainers""" - current_env = copy.copy(os.environ.copy()) - # paddle broadcast ncclUniqueId use socket, and - # proxy maybe make trainers unreachable, so delete them. - # if we set them to "", grpc will log error message "bad uri" - # so just delete them. 
- current_env.pop("http_proxy", None) - current_env.pop("https_proxy", None) - - procs = [] - for t in pod.trainers: - proc_env = { - "FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]), - "PADDLE_TRAINER_ID": f"{t.rank}", - "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}", - "PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}", - "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), - "PYTHONPATH": str(Path(__file__).resolve().parent), - } - - proc_env["FLAGS_allocator_strategy"] = allocator_strategy - if allocator_strategy == "auto_growth": - proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1" - - current_env.update(proc_env) - - print(f"trainer proc env:{current_env}") - - if os.getenv("WITH_COVERAGE", "OFF") == "ON": - cmd = "python -m coverage run --branch -p " + training_script - else: - cmd = "python -u " + training_script - - print(f"start trainer proc:{cmd} env:{proc_env}") - - fn = None - - proc = subprocess.Popen( - cmd.split(" ") + training_script_args, env=current_env - ) # pylint: disable=consider-using-with - - tp = TrainerProc() - tp.proc = proc - tp.rank = t.rank - tp.log_fn = fn - tp.cmd = cmd - - procs.append(tp) - - return procs - - -class TestDistributed(unittest.TestCase): - """Base class for distributed test""" - - @staticmethod - def run_2gpu( - target_file_name, - allocator_strategy="auto_growth", - ): - """Run target file in subprocesses""" - if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0: - return - - selected_gpus = get_gpus("0,1") - cluster = None - pod = None - - cluster, pod = get_cluster_from_args(selected_gpus) - - procs = start_local_trainers( - cluster, - pod, - allocator_strategy=allocator_strategy, - training_script=target_file_name, - training_script_args=[], - ) - - while True: - alive = watch_local_trainers(procs, cluster.trainers_endpoints()) - - if not alive: - print(f"Local procs complete, POD info:{pod}") - break - time.sleep(3) diff --git 
a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py deleted file mode 100644 index 3e0a6d2bac..0000000000 --- a/tests/paddle/parallel_tests/amax_reduction.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Linear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -def assert_allclose_across_ranks(tensor, group=None): - """Assert tensor is identical in all ranks""" - gathered_list = [] - paddle.distributed.all_gather(gathered_list, tensor, group=group) - assert len(gathered_list) > 1 - for gathered_tensor in gathered_list: - assert_allclose(tensor, gathered_tensor) - - -class TestAmaxReduction(unittest.TestCase): - """Tests Amax reduction""" - - def setUp(self): - self.data_parallel_size = 2 - self.init_dist_env() - self.global_dtype = "bfloat16" - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": self.data_parallel_size, - "mp_degree": 1, - "pp_degree": 1, - } - fleet.init(is_collective=True, strategy=strategy) - - def test_amax_reduction(self): - """Tests column parallel linear""" - set_random_seed(1024) - layer1 = te.Linear(16, 16) - layer2 = te.Linear(16, 16) - model = paddle.nn.Sequential(layer1, layer2) - model = fleet.distributed_model(model) - - rank_id = paddle.distributed.get_rank() - set_random_seed(rank_id) - - optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters()) - optimizer = fleet.distributed_optimizer(optimizer) - - def train_one_step(layer, inp, optimizer): - inp = paddle.to_tensor(inp) - inp.stop_gradient = False - out = layer(inp) - loss = out.mean() - 
loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([16, 16], self.global_dtype) - with te.fp8_autocast(enabled=True): - train_one_step(model, inp, optimizer) - - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].amax_history[-1]) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale_inv) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].amax_history[-1]) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale_inv) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].amax_history[-1]) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale_inv) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].amax_history[-1]) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/attention_tp.py b/tests/paddle/parallel_tests/attention_tp.py deleted file mode 100644 index c0ffa288ee..0000000000 --- a/tests/paddle/parallel_tests/attention_tp.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for Transformer layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks -import transformer_engine.paddle as te - - -class TestAttentionTp(unittest.TestCase): - """Tests MultiHeadAttention layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-3 - self.atol = 5e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False): - inp, mask = inp_list - if sequence_parallel: - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - with te.fp8_autocast(enabled=fp8_enabled): - out = layer(input_parallel, mask) - if sequence_parallel: - total_out = mp_ops._c_concat(out, group=self.tp_group) - total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), 
axis=0) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss, total_out - - def test_parallel_layer(self): - """Tests parallel Transformer""" - set_random_seed(1024) - common_args = ( - self.hidden_size, - self.num_heads, - ) - common_kwargs = { - "layernorm_epsilon": self.eps, - "attention_dropout": 0.0, - "attn_mask_type": self.mask_type, - "attention_type": "self", - "tp_group": self.tp_group, - "input_layernorm": True, - } - - layer_tp = te.MultiHeadAttention( - *common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False) - - def _get_total_weight(local_weight, tp_group, axis, interleave=False): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - if interleave: - # Due to the interleaved qkv layout, need to concat on num_head - # dimension for column parallel linear in MultiHeadAttention layer - assert axis == 0 - assert [ - 3 * self.hidden_size // self.world_size, - self.hidden_size, - ] == partial_weight.shape - local_num_head = self.num_heads // self.world_size - for idx, _ in enumerate(total_weight): - total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size] - ) - total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) - else: - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - def _get_weight(obj, weight_names): - for name in weight_names: - obj = getattr(obj, name) - return obj - - def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False): - weight_src = _get_weight(layer_src, weight_names) - weight_dst = _get_weight(layer_dst, weight_names) - if partition_mode is None: - total_weight = weight_src - elif partition_mode == "column": 
- total_weight = _get_total_weight( - weight_src, tp_group=self.tp_group, axis=0, interleave=interleave - ) - elif partition_mode == "row": - total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) - else: - raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert ( - weight_dst.shape == total_weight.shape - ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." - weight_dst.copy_(total_weight, True) - - copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"]) - copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True) - copy_weight(layer_tp, layer_single, "row", ["proj", "weight"]) - - if self.sequence_parallel: - register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) - - optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD( - learning_rate=0.01, parameters=layer_single.parameters() - ) - - layer_tp = fleet.distributed_model(layer_tp) - optimizer_tp = fleet.distributed_optimizer(optimizer_tp) - - for _ in range(5): - inp = paddle.uniform( - [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype - ) - mask = paddle.zeros( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" - ) - loss_tp, out_tp = self._train_one_step( - layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel - ) - loss_single, out_single = self._train_one_step( - layer_single, [inp, mask], optimizer_single, self.fp8 - ) - assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) - assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) - - -class TestAttentionTpFp8(TestAttentionTp): - """Tests MultiHeadAttention layer with model parallel in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 
128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestAttentionSp(TestAttentionTp): - """Tests MultiHeadAttention layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-3 - self.atol = 5e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestAttentionSpFp8(TestAttentionTp): - """Tests MultiHeadAttention layer with sequence parallel in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 1e-1 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py deleted file mode 100644 index 21d08a8ef3..0000000000 --- a/tests/paddle/parallel_tests/group_sharding.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for group sharding""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( - DygraphShardingOptimizer, -) - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -class TestGroupSharding(unittest.TestCase): - """Tests group sharding""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def set_attr(self): - """Set test configs""" - self.sharding_degree = 2 - self.global_dtype = "float32" - self.rtol = 1e-5 - self.atol = 1e-5 - self.batch_size = 16 - self.in_channels = 16 - self.out_channels = 32 - self.fp8 = False - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": 1, - "sharding_degree": self.sharding_degree, - } - self.strategy = strategy - fleet.init(is_collective=True, strategy=strategy) - - def _get_model_and_optimizer(self, model, stage): - if stage == 1: - optimizer = DygraphShardingOptimizer( - paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()), - fleet.get_hybrid_communicate_group(), - ) - model = fleet.distributed_model(model) - optimizer = fleet.distributed_optimizer(optimizer) - elif stage in [2, 3]: - optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) - group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() - - class ShardingLevel: # pylint: disable=too-few-public-methods, - """Paddle sharding options""" - - kStage1 = "os" - kStage2 = "os_g" - kStage3 = "p_g_os" - - level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 - model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( - model=model, - optimizer=optimizer, - level=level, - group=group, - segment_size=256, - ) - else: - raise ValueError(f"Stage 
{stage} not supported") - return model, optimizer - - def test_group_sharding_stage1(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - assert ( - len(optimizer_te.state_dict()) == 4 - ), "Expect each rank to hold 4 optimizer state entries." 
- - def test_group_sharding_stage2(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - # Check gradients are split to different trainers - if rank_id == 0: - assert model.bias.grad is None and model.weight.grad is not None - elif rank_id == 1: - assert model.weight.grad is None and model.bias.grad is not None - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - assert ( - len(optimizer_te.state_dict()) == 4 - ), "Expect each rank to hold 4 optimizer state entries." 
- - def test_group_sharding_stage3(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - for name, value in optimizer_te.state_dict().items(): - if name.endswith("w_0_moment1_0"): - assert ( - value.numel() == self.in_channels * self.out_channels // self.sharding_degree - ), "Expect optimizer state to be sharded across trainers." - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py deleted file mode 100644 index 96070a03c5..0000000000 --- a/tests/paddle/parallel_tests/layernorm_linear_tp.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for LayerNormLinear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLayerNormLinearTp(unittest.TestCase): - """Tests LayerNormLinear layer with column/row parallelism in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - 
total_out = mp_ops._c_concat(out, group=self.tp_group) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_column_parallel_layer(self): - """Tests column parallel LayerNormLinear""" - set_random_seed(1024) - layer_te = te.LayerNormLinear( - self.in_features, - self.out_features, - eps=self.eps, - parallel_mode="column", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.LayerNormLinear( - self.in_features, - self.out_features, - eps=self.eps, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=0) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] - ) - assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=True, - ) - loss_ref, 
grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): - """Tests LayernormLinear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestLayerNormLinearSp(TestLayerNormLinearTp): - """Tests LayernormLinear layer with sequence parallelism""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): - """Tests LayernormLinear layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py b/tests/paddle/parallel_tests/layernorm_mlp_tp.py deleted file mode 100644 index 9ec09c7e7a..0000000000 --- a/tests/paddle/parallel_tests/layernorm_mlp_tp.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for LayerNormMLP layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLayerNormMLPTp(unittest.TestCase): - """Tests LayerNormMLP layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - # Need to concat 
on the first dim, while _c_concat concats on the last dim - total_out = mp_ops._c_concat(out.T, group=self.tp_group).T - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_parallel_layer(self): - """Tests parallel LayerNormMLP""" - set_random_seed(1024) - layer_te = te.LayerNormMLP( - hidden_size=self.hidden_size, - ffn_hidden_size=self.ffn_hidden_size, - eps=self.eps, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.LayerNormMLP( - hidden_size=self.hidden_size, - ffn_hidden_size=self.ffn_hidden_size, - eps=self.eps, - set_parallel_mode=False, - backend="paddle", - ) - - def _get_total_weight(local_weight, tp_group, axis): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - # Get total weight - total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0) - total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1) - layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) - layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) - - assert_shape( - layer_te.fc1_weight, - [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size], - ) - assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) - assert_shape( - layer_te.fc2_weight, - [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size], - ) - assert_shape(layer_te.fc2_bias, 
[self.hidden_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=self.sequence_parallel, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with tensor parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestLayerNormMLPSp(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - 
self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py deleted file mode 100644 index 68271e52e7..0000000000 --- a/tests/paddle/parallel_tests/linear_pp.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Linear layer in pipeline parallel""" - -import unittest - -import numpy as np - -import paddle -from paddle.distributed import fleet - -from paddle.distributed.fleet.meta_parallel import ( - LayerDesc, - PipelineLayer, -) - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -class TELinear(te.Linear): - """To pass is_first_microbatch""" - - def __init__(self, *args, **kwargs): - assert "accumulate_steps" in kwargs - self.accumulate_steps = kwargs["accumulate_steps"] - del kwargs["accumulate_steps"] - self._micro_batch_id = 0 - super().__init__(*args, **kwargs) - - def forward(self, *args, **kwargs): - kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0 - if paddle.is_grad_enabled() and self.training: - self._micro_batch_id += 1 - return super().forward(*args, **kwargs) - - -class TEPipelineModel(PipelineLayer): - """Model for pipeline parallel test""" - - def __init__( - self, - in_features, - hidden_features, - weight_attrs, - use_te=True, - use_fp8=False, - accumulate_steps=1, - **kwargs, - ): - self.in_features = in_features - self.hidden_features = hidden_features - self.fp8 = use_fp8 - hcg = fleet.get_hybrid_communicate_group() - self.dp_group = hcg.get_data_parallel_group() - - Linear = TELinear if use_te else paddle.nn.Linear - extra_kwargs = {} - if use_te: - extra_kwargs["accumulate_steps"] = accumulate_steps - - model_desc = [ - LayerDesc( - Linear, - self.in_features, - self.hidden_features, - weight_attr=weight_attrs[0], - 
**extra_kwargs, - ), - LayerDesc( - Linear, - self.hidden_features, - self.in_features, - weight_attr=weight_attrs[1], - **extra_kwargs, - ), - ] - super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) - - def forward(self, *args, **kwargs): - with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group): - return super().forward(*args, **kwargs) - - -class StandaloneModel(paddle.nn.Layer): - """Model for pipeline parallel test""" - - def __init__(self, in_features, hidden_features, weight_attrs): - super().__init__() - self.in_features = in_features - self.hidden_features = hidden_features - Linear = paddle.nn.Linear - self.layer = paddle.nn.Sequential( - Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), - Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), - ) - self.loss = paddle.nn.CrossEntropyLoss() - - def forward(self, inp): - out = self.layer(inp[0]) - loss = self.loss(out, inp[1]) - return loss - - -class TestLinearPipelineParallel(unittest.TestCase): - """Tests Linear layer with pipeline parallel""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.pipeline_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": self.pipeline_parallel_size, - } - self.accumulate_steps = self.batch_size // self.micro_batch_size - strategy.pipeline_configs = { - "accumulate_steps": self.accumulate_steps, - "micro_batch_size": self.micro_batch_size, - } - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - - def set_attr(self): - """Set test configs""" - self.batch_size = 32 - self.micro_batch_size = 16 - self.in_features = 32 - self.hidden_features = 64 - self.global_dtype = "float32" - self.rtol 
= 1e-5 - self.atol = 1e-5 - self.iter = 10 - self.fp8 = False - - def test_pipeline_train(self): - """Test pipeline parallel training""" - set_random_seed(1024) - np.random.seed(1024) - - weight1_np = np.random.normal(size=[self.in_features, self.hidden_features]) - weight2_np = np.random.normal(size=[self.hidden_features, self.in_features]) - weight_attrs = [ - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)), - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)), - ] - weight_attrs_transposed = [ - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)), - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)), - ] - - pipe_model = TEPipelineModel( - self.in_features, - self.hidden_features, - weight_attrs_transposed, - use_te=True, - use_fp8=self.fp8, - seg_method="layer:Linear", - num_stages=self.pipeline_parallel_size, - accumulate_steps=self.accumulate_steps, - ) - - # Check if model is split across ranks as expected - for name, sublayer in pipe_model.named_sublayers(): - if name in ("_loss_fn", "shared_layers"): - continue - if self.rank == 0: - assert tuple(sublayer.weight.shape) == weight1_np.T.shape, ( - f"Shape does not match, expect: {weight1_np.T.shape}, " - f"actual: {tuple(sublayer.weight.shape)}" - ) - elif self.rank == 1: - assert tuple(sublayer.weight.shape) == weight2_np.T.shape, ( - f"Shape does not match, expect: {weight2_np.T.shape}, " - f"actual: {tuple(sublayer.weight.shape)}" - ) - - standalone_model = StandaloneModel( - self.in_features, - self.hidden_features, - weight_attrs, - ) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) - optimizer_pd = paddle.optimizer.SGD( - learning_rate=0.1, parameters=standalone_model.parameters() - ) - - pipe_model = fleet.distributed_model(pipe_model) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - def train_one_step(layer, inp, optimizer): - loss = layer(inp) - 
loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for i in range(self.iter): - inp = paddle.to_tensor( - np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype - ) - label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) - loss_te = pipe_model.train_batch([inp, label], optimizer_te) - loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) - print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}") - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - -class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): - """Tests Linear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 32 - self.micro_batch_size = 16 - self.in_features = 32 - self.hidden_features = 64 - self.global_dtype = "float32" - self.rtol = 5e-2 - self.atol = 5e-2 - self.iter = 10 - self.fp8 = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py deleted file mode 100644 index 1a42d6c621..0000000000 --- a/tests/paddle/parallel_tests/linear_tp.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for Linear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLinearTp(unittest.TestCase): - """Tests Linear layer with column/row parallelism in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - total_out = mp_ops._c_concat(out, 
group=self.tp_group) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_column_parallel_layer(self): - """Tests column parallel linear""" - set_random_seed(1024) - layer_te = te.Linear( - self.in_features, - self.out_features, - parallel_mode="column", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.Linear( - self.in_features, - self.out_features, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=0) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] - ) - assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=True, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, 
loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - def test_row_parallel_layer(self): - """Tests row parallel linear""" - set_random_seed(1024) - layer_te = te.Linear( - self.in_features, - self.out_features, - parallel_mode="row", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.Linear( - self.in_features, - self.out_features, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=1) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size] - ) - assert_shape(layer_te.bias, [self.out_features]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="column", - gather_output=self.sequence_parallel, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLinearTpFP8(TestLinearTp): - """Tests Linear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.fp8 = 
True - self.sequence_parallel = False - - -class TestLinearSp(TestLinearTp): - """Tests Linear layer with sequence parallelism""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLinearSpFP8(TestLinearTp): - """Tests Linear layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py deleted file mode 100644 index 5fc3e7ddf3..0000000000 --- a/tests/paddle/parallel_tests/transformer_tp.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for Transformer layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks -import transformer_engine.paddle as te - - -class TestTransformerTp(unittest.TestCase): - """Tests Transformer layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False): - inp, mask = inp_list - if sequence_parallel: - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - with te.fp8_autocast(enabled=fp8_enabled): - out = layer(input_parallel, mask) - if sequence_parallel: - total_out = mp_ops._c_concat(out, group=self.tp_group) - total_out = 
paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss, total_out - - def test_parallel_layer(self): - """Tests parallel Transformer""" - set_random_seed(1024) - common_args = [ - self.hidden_size, - self.ffn_hidden_size, - self.num_heads, - ] - common_kwargs = { - "layernorm_epsilon": self.eps, - "hidden_dropout": 0.0, - "attention_dropout": 0.0, - "self_attn_mask_type": self.mask_type, - "layer_type": self.layer_type, - } - layer_tp = te.TransformerLayer( - *common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) - - def _get_total_weight(local_weight, tp_group, axis, interleave=False): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - if interleave: - # Due to the interleaved qkv layout, need to concat on num_head - # dimension for column parallel linear in MultiHeadAttention layer - assert axis == 0 - assert [ - 3 * self.hidden_size // self.world_size, - self.hidden_size, - ] == partial_weight.shape - local_num_head = self.num_heads // self.world_size - for idx, _ in enumerate(total_weight): - total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size] - ) - total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) - else: - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - def _get_weight(obj, weight_names): - for name in weight_names: - obj = getattr(obj, name) - return obj - - def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False): - weight_src = _get_weight(layer_src, weight_names) - weight_dst = _get_weight(layer_dst, weight_names) - if partition_mode is None: 
- total_weight = weight_src - elif partition_mode == "column": - total_weight = _get_total_weight( - weight_src, tp_group=self.tp_group, axis=0, interleave=interleave - ) - elif partition_mode == "row": - total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) - else: - raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert ( - weight_dst.shape == total_weight.shape - ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." - weight_dst.copy_(total_weight, True) - - copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"]) - copy_weight( - layer_tp, - layer_single, - "column", - ["self_attention", "layernorm_qkv", "weight"], - interleave=True, - ) - copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"]) - copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"]) - copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"]) - copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"]) - - if self.sequence_parallel: - register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) - - optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD( - learning_rate=0.01, parameters=layer_single.parameters() - ) - - layer_tp = fleet.distributed_model(layer_tp) - optimizer_tp = fleet.distributed_optimizer(optimizer_tp) - - for _ in range(5): - inp = paddle.uniform( - [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype - ) - mask = paddle.zeros( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" - ) - loss_tp, out_tp = self._train_one_step( - layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel - ) - loss_single, out_single = self._train_one_step( - layer_single, [inp, mask], optimizer_single, self.fp8 - ) - assert_allclose(out_tp, out_single, 
rtol=self.rtol, atol=self.atol) - assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) - - -class TestTransformerTpFp8(TestTransformerTp): - """Tests Transformer layer with tensor parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 0.5 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestTransformerSp(TestTransformerTp): - """Tests Transformer layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestTransformerSpFp8(TestTransformerSp): - """Tests Transformer layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 0.5 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/recompute_tests/recompute_transformer_encoder.py b/tests/paddle/recompute_tests/recompute_transformer_encoder.py deleted file mode 100644 index e753f750c5..0000000000 --- a/tests/paddle/recompute_tests/recompute_transformer_encoder.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2022-2025, 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TransformerLayer encoder recompute""" - -import sys -import paddle -import transformer_engine.paddle as te - - -class Net(paddle.nn.Layer): - """Network use for recompute testing""" - - def __init__(self, layers): - super().__init__() - self.layers = layers - - def forward(self, inp, mask, enable_recompute, use_reentrant): - for layer in self.layers: - if enable_recompute: - out = te.recompute(layer, inp, mask, use_reentrant=use_reentrant) - else: - out = layer(inp, mask) - return out - - -def main(): - """Main function""" - paddle.seed(10) - batch_size = 16 - hidden_size = 4096 - num_heads = 32 - ffn_hidden_size = 16384 - q_seqlen = 512 - kv_seqlen = 512 - num_layers = 4 - enable_recompute = int(sys.argv[1]) - use_reentrant = int(sys.argv[2]) - - layers = paddle.nn.LayerList( - [ - te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layer_type="encoder", - ) - for _ in range(num_layers) - ] - ) - model = Net(layers) - - optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters()) - - for _ in range(10): - inp = paddle.uniform([batch_size, q_seqlen, hidden_size]) - inp.stop_gradient = False - mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool") - with te.fp8_autocast(enabled=True): - out = model(inp, mask, enable_recompute, use_reentrant) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - - print("Loss: ", float(loss)) - print("Peak memory: ", paddle.device.cuda.max_memory_allocated(0)) - - -if __name__ == "__main__": - main() diff --git a/tests/paddle/test_install.py b/tests/paddle/test_install.py deleted file mode 100644 index 1c317584ed..0000000000 --- a/tests/paddle/test_install.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Test basic installation of Paddle extensions""" - - -def test_import(): - """ - Test if Paddle extension can be imported normally - """ - import transformer_engine.paddle # pylint: disable=unused-import diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py deleted file mode 100644 index fbd6c61ad7..0000000000 --- a/tests/paddle/test_layers.py +++ /dev/null @@ -1,1663 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE Paddle Layer-level APIs""" - -import os -from utils import assert_allclose, is_fused_attention_supported - -import paddle -import pytest - -from transformer_engine.common.recipe import DelayedScaling -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast - -is_fp8_supported, reason = is_fp8_available() -LINEAR_CASES = [(16, 16, 32), (32, 32, 64)] -NORM_CASES = [(16, 32), (256, 1024)] - - -@pytest.fixture(autouse=True) -def setup(): - """Setup random seed before each test""" - paddle.seed(10) - yield - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("use_fp8", [True, False]) -def test_checkpoint(use_fp8): - """Test checkpoint save / load""" - bs = 16 - in_features = 16 - out_features = 32 - file_name = "model.pdparams" - input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32") - model = te.Linear(in_features, out_features) - model_loaded = te.Linear(in_features, out_features) - # Populate amax_history - with fp8_autocast(enabled=False, calibrating=True): - _ = model(input_tensor) - # Save model - paddle.save(model.state_dict(), file_name) - # Get ref output - with fp8_autocast(enabled=use_fp8): - out_ref = model(input_tensor) - # Load model - model_loaded.set_state_dict(paddle.load(file_name)) - if os.path.exists(file_name): - os.remove(file_name) - # Get actual output - with fp8_autocast(enabled=use_fp8): - out = 
model_loaded(input_tensor) - - assert_allclose(out, out_ref) - - -def calc_output_and_grad(layer, x, dy): - """ - Calculate forward and backward pass - """ - inp = paddle.to_tensor(x) - inp.stop_gradient = x.stop_gradient - y = layer(inp) - y.backward(dy) - - return y, inp.grad if not inp.stop_gradient else None - - -@staticmethod -def calc_output_and_grad_ln_out(layer, x, dy, return_ln_out=False): - """ - Calculate forward and backward pass for layernorm - """ - inp = paddle.to_tensor(x) - inp.stop_gradient = x.stop_gradient - outputs = layer(inp) - ln_out = None - if return_ln_out: - y, ln_out = outputs - else: - y = outputs - y.backward(dy) - - return y, ln_out, inp.grad if not inp.stop_gradient else None - - -class TestLinear: - """ - Tests for Linear layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - def test_linear_bf16( - bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype - ): - """ - Test BF16 Linear - """ - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False) - layer_pd = te.Linear( - in_features, out_features, bias_attr=None if has_bias else False, backend="paddle" - ) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, 
True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) - out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - def test_linear_fp8( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - activation_dtype, - ): - """ - Test FP8 Linear - """ - rtol = 0.1 - atol = 0.5 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.Linear( - in_features=in_features, - out_features=out_features, - bias_attr=None if has_bias else False, - ) - layer_pd = te.Linear( - in_features=in_features, - 
out_features=out_features, - bias_attr=None if has_bias else False, - backend="paddle", - ) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) - out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch): - """ - Test FP8 Linear - """ - rtol = 0.1 - atol = 0.1 - - recipe = DelayedScaling() - - paddle.set_default_dtype(activation_dtype) - layer_cached = te.Linear( - in_features=in_features, - out_features=out_features, - ) - layer_normal = te.Linear( - in_features=in_features, - out_features=out_features, - ) - layer_cached.weight.copy_(layer_normal.weight, True) - layer_cached.bias.copy_(layer_normal.bias, True) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, in_features), 
dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol - ) - - -@pytest.mark.parametrize("bs,hidden_size", NORM_CASES) -@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) -@pytest.mark.parametrize("no_dgrad", [True, False]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) -def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype): - """ - Test BF16 LayerNorm - """ - eps = 1e-3 - rtol = 1e-2 - atol = 1e-2 - - x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - x.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False) - layer_pd = te.LayerNorm( - hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle" - ) - layer_pd.weight.copy_(layer_te.weight, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, x, grad_out) - out, grad_input = calc_output_and_grad(layer_te, x, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - 
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - - -class TestLayerNormLinear: - """ - Tests for LayerNormLinear layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - def test_layernorm_linear_bf16( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - return_ln_out, - activation_dtype, - normalization, - ): - """ - Test BF16 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - layer_te = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - - 
layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - 
@pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - def test_layernorm_linear_fp8( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - return_ln_out, - activation_dtype, - normalization, - ): - """ - Test FP8 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 0.1 - atol = 0.75 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - layer_te = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, 
calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_layernorm_linear_fp8_microbatch( - bs, in_features, out_features, activation_dtype, num_microbatch - ): - """ - Test FP8 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - eps = 1e-3 - rtol = 0.5 - atol = 0.5 - - recipe = DelayedScaling() - - layer_cached = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - ) - - layer_normal = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - ) - - layer_cached.ln_weight.copy_(layer_normal.ln_weight, True) - layer_cached.ln_bias.copy_(layer_normal.ln_bias, True) - layer_cached.weight.copy_(layer_normal.weight, True) - 
layer_cached.bias.copy_(layer_normal.bias, True) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol - ) - - -class TestLayerNormMLP: - """ - Test LayerNormMLP Layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) - def test_layernorm_mlp_bf16( - bs, - hidden_size, - ffn_hidden_size, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - return_ln_out, - activation_dtype, - normalization, - activation, - ): - """ - Tests for TestLayerNormMLP layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = 
paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - layer_te = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - layer_pd = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True) - layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True) - if has_bias: - layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True) - layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True) - - layer_te.fc1_weight.stop_gradient = no_wgrad - layer_te.fc2_weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.fc1_weight.stop_gradient = no_wgrad - layer_pd.fc2_weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.fc1_bias.stop_gradient = no_dbias - layer_te.fc2_bias.stop_gradient = no_dbias - layer_pd.fc1_bias.stop_gradient = no_dbias - layer_pd.fc2_bias.stop_gradient = no_dbias - - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - 
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose( - layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol - ) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose( - layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol - ) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) - def test_layernorm_mlp_fp8( - bs, - hidden_size, - ffn_hidden_size, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - return_ln_out, - activation_dtype, - normalization, - activation, - ): - """ - Test FP8 LayerNormMLP Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 0.1 - atol = 0.75 - - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = 
normalization == "LayerNorm" - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - layer_te = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True) - layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True) - if has_bias: - layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True) - layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True) - - layer_te.fc1_weight.stop_gradient = no_wgrad - layer_te.fc2_weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.fc1_weight.stop_gradient = no_wgrad - layer_pd.fc2_weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.fc1_bias.stop_gradient = no_dbias - layer_te.fc2_bias.stop_gradient = no_dbias - layer_pd.fc1_bias.stop_gradient = no_dbias - layer_pd.fc2_bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not 
no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose( - layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol - ) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose( - layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol - ) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_layernorm_mlp_fp8_microbatch( - bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch - ): - """ - Test FP8 LayerNormMLP Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 1e-5 - atol = 1e-5 - eps = 1e-3 - - recipe = DelayedScaling() - - layer_cached = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - ) - - layer_normal = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - ) - layer_normal.ln_weight.copy_(layer_cached.ln_weight, True) - layer_normal.ln_bias.copy_(layer_cached.ln_bias, True) - layer_normal.fc1_weight.copy_(layer_cached.fc1_weight, True) - layer_normal.fc2_weight.copy_(layer_cached.fc2_weight, True) - layer_normal.fc1_bias.copy_(layer_cached.fc1_bias, True) - 
layer_normal.fc2_bias.copy_(layer_cached.fc2_bias, True) - - # Calibration to make sure weight scale is the same - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_cached(input_tensor) - - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_normal(input_tensor) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol - ) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("attn_type", ["self", "cross"]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("deterministic", [True, False]) -def test_dot_product_attention( - bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic -): - """ - Test DotProductAttention Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 1e-4 - atol = 2e-2 - head_size = hidden_size // num_heads - - # Skip if cuDNN fused 
attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - attn_q_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size) - ).astype(math_dtype) - attn_k_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) - ).astype(math_dtype) - attn_v_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) - ).astype(math_dtype) - - q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32") - kv_actual_seqlen = ( - paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32") - if attn_type == "cross" - else q_actual_seqlen - ) - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - head_size = hidden_size // num_heads - - if deterministic: - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" - - layer_te = te.DotProductAttention( - num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend="transformer_engine", - ) - layer_pd = te.DotProductAttention( - num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend="paddle", - ) - - def calc_attn_output_and_grad(layer, q, k, v, mask, dout): - _q = paddle.to_tensor(q, stop_gradient=False) - _k = paddle.to_tensor(k, stop_gradient=False) - _v = 
paddle.to_tensor(v, stop_gradient=False) - - out = layer(_q, _k, _v, mask) - out.backward(dout) - return out, _q.grad, _k.grad, _v.grad - - out, q_grad, k_grad, v_grad = calc_attn_output_and_grad( - layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad( - layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - valid_out_ref = paddle.full_like(out_ref, 0) - for i in range(0, bs): - valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :] - - valid_q_grad_ref = paddle.full_like(q_grad_ref, 0) - valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) - valid_v_grad_ref = paddle.full_like(v_grad_ref, 0) - for i in range(0, bs): - valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[ - i, 0 : q_actual_seqlen[i], :, : - ] - valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[ - i, 0 : kv_actual_seqlen[i], :, : - ] - valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[ - i, 0 : kv_actual_seqlen[i], :, : - ] - - assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol) - assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) - assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) - assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) - if deterministic: - out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad( - layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - assert_allclose(out, out2, rtol=1e-12, atol=1e-12) - assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12) - assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12) - assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12) - os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) 
-@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("output_layernorm", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) -def test_transformer_encoder_layer( - bs, - hidden_size, - num_heads, - num_gqa_groups, - ffn_hidden_size, - has_bias, - no_dbias, - no_wgrad, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - output_layernorm, - return_layernorm_output, - normalization, -): - """ - Test Transformer Encoder Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 5e-2 - atol = 5e-2 - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - layer_te = 
te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="encoder", - normalization=normalization, - backend="transformer_engine", - ) - layer_pd = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="encoder", - normalization=normalization, - backend="paddle", - ) - - # MultiHeadAttention params - if output_layernorm: - layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) - layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) - layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.qkv.bias.stop_gradient = no_dbias - else: - layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True - ) - layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True - ) - layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - if has_ln_bias: - 
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True - ) - layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True - ) - layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - - layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) - layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad - layer_te.self_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) - layer_pd.self_attention.proj.bias.stop_gradient = no_dbias - layer_te.self_attention.proj.bias.stop_gradient = no_dbias - - # LayerNorm MLP params - layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) - layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) - layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) - layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) - layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) - layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) - 
layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - - if output_layernorm: - layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) - layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) - layer_pd.layernorm.weight.stop_gradient = no_wgrad - layer_pd.layernorm.bias.stop_gradient = no_dbias - layer_te.layernorm.weight.stop_gradient = no_wgrad - layer_te.layernorm.bias.stop_gradient = no_dbias - - def calc_transformer_output_and_grad(layer, encoder_input, mask, dout): - _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) - out = layer(_encoder_input, mask) - out.backward(dout) - return out, _encoder_input.grad - - out_ref, grad_input_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, grad_out - ) - out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - if not no_dbias: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) -@pytest.mark.parametrize("hidden_size, num_heads, 
ffn_hidden_size", [[256, 4, 1024]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("output_layernorm", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize("recompute_core_attention", [True, False]) -@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) -def test_transformer_decoder_layer( - bs, - hidden_size, - num_heads, - num_gqa_groups, - ffn_hidden_size, - has_bias, - no_dbias, - no_wgrad, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - output_layernorm, - return_layernorm_output, - recompute_core_attention, - normalization, -): - """ - Test Transformer Decoder Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 5e-2 - atol = 6e-2 - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype( - math_dtype - ) - encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype( - math_dtype - ) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, 
hidden_size)).astype( - "float32" - ) - - # rounding to avoid numerical issues - encoder_input = paddle.round(encoder_input * 1000) / 1000 - encoder_output = paddle.round(encoder_output * 1000) / 1000 - grad_out = paddle.round(grad_out * 1000) / 1000 - - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="decoder", - normalization=normalization, - backend="transformer_engine", - ) - layer_pd = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="decoder", - normalization=normalization, - backend="paddle", - ) - - # MultiHeadAttention params - self attn - if output_layernorm: - layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) - layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) - layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.qkv.bias.stop_gradient = no_dbias - else: - layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - 
layer_te.self_attention.layernorm_qkv.ln_weight, True - ) - layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True - ) - layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True - ) - layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True - ) - layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - - layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) - layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad - layer_te.self_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) - layer_pd.self_attention.proj.bias.stop_gradient = no_dbias - layer_te.self_attention.proj.bias.stop_gradient = no_dbias - - # MultiHeadAttention params - cross attn - layer_pd.inter_attention.layernorm_query.ln_weight.copy_( - layer_te.inter_attention.layernorm_query.ln_weight, True - ) - layer_pd.inter_attention.layernorm_query.weight.copy_( - layer_te.inter_attention.layernorm_query.weight.T, True - ) - layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad - layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad - layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad - 
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.inter_attention.layernorm_query.ln_bias.copy_( - layer_te.inter_attention.layernorm_query.ln_bias, True - ) - layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias - layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.inter_attention.layernorm_query.bias.copy_( - layer_te.inter_attention.layernorm_query.bias, True - ) - layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - - layer_pd.inter_attention.key_value.weight.copy_( - layer_te.inter_attention.key_value.weight.T, True - ) - layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad - layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad - layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True) - layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad - layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True) - layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias - layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias - layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True) - layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias - layer_te.inter_attention.proj.bias.stop_gradient = no_dbias - - # LayerNorm MLP params - layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) - layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) - layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) - layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - 
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) - layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) - layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) - layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - - if output_layernorm: - layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) - layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) - layer_pd.layernorm.weight.stop_gradient = no_wgrad - layer_pd.layernorm.bias.stop_gradient = no_dbias - layer_te.layernorm.weight.stop_gradient = no_wgrad - layer_te.layernorm.bias.stop_gradient = no_dbias - - def calc_transformer_output_and_grad( - layer, - encoder_input, - mask, - encoder_output, - enc_dec_attn_mask, - dout, - recompute_core_attention=False, - ): - _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) - _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False) - out = layer( - _encoder_input, - mask, - _encoder_output, - enc_dec_attn_mask, - recompute_core_attention=recompute_core_attention, - ) - out.backward(dout) - return out, _encoder_input.grad, _encoder_output.grad - - out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out - ) - out, grad_encoder_input, grad_encoder_output = 
calc_transformer_output_and_grad( - layer_te, - encoder_input, - attn_mask, - encoder_output, - attn_mask, - grad_out, - recompute_core_attention=recompute_core_attention, - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol) - assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol) - if not no_wgrad: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - assert_allclose( - layer_te.inter_attention.layernorm_query.weight.grad, - layer_pd.inter_attention.layernorm_query.weight.grad.T, - rtol=rtol, - atol=atol, - ) - if not no_dbias: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.5, - atol=0.6, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - assert_allclose( - layer_te.inter_attention.layernorm_query.bias.grad, - layer_pd.inter_attention.layernorm_query.bias.grad, - rtol=rtol, - atol=atol, - ) - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("bs", [8]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]]) -@pytest.mark.parametrize("mask_type", ["causal"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16"]) -@pytest.mark.parametrize("num_microbatch", [8]) -def test_transformer_encoder_layer_microbatch( - bs, - hidden_size, - num_heads, - ffn_hidden_size, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - num_microbatch, -): - """ - Test Transformer 
Encoder Layer with FP8 weight caching - """ - paddle.set_default_dtype(math_dtype) - rtol = 1e-5 - atol = 1e-5 - eps = 1e-3 - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - layer_cached = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type="encoder", - ) - layer_normal = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type="encoder", - ) - - layer_normal.self_attention.layernorm_qkv.ln_weight.copy_( - layer_cached.self_attention.layernorm_qkv.ln_weight, True - ) - layer_normal.self_attention.layernorm_qkv.ln_bias.copy_( - layer_cached.self_attention.layernorm_qkv.ln_bias, True - ) - layer_normal.self_attention.layernorm_qkv.weight.copy_( - layer_cached.self_attention.layernorm_qkv.weight, True - ) - layer_normal.self_attention.layernorm_qkv.bias.copy_( - layer_cached.self_attention.layernorm_qkv.bias, True - ) - - layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True) - layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True) - - # LayerNorm MLP params - layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True) - layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True) - layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True) - 
layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True) - layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True) - layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True) - - recipe = DelayedScaling() - - def generate_input(): - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - return encoder_input, attn_mask, grad_out - - # Calibration to make sure weight scale is the same - encoder_input, mask, _ = generate_input() - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_cached(encoder_input, mask) - - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_normal(encoder_input, mask) - - for iteration in range(num_microbatch): - encoder_input, mask, grad_out = generate_input() - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(encoder_input, mask) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.self_attention.layernorm_qkv.weight.grad, - layer_normal.self_attention.layernorm_qkv.weight.grad, - rtol=rtol, - atol=atol, - ) diff --git a/tests/paddle/test_master_grad.py b/tests/paddle/test_master_grad.py deleted file mode 100644 
index c896a7871c..0000000000 --- a/tests/paddle/test_master_grad.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TransformerLayer encoder main_grad""" - -import numpy as np -import pytest - -import paddle -from paddle.distributed.fleet.utils import mix_precision_utils - -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available - -is_fp8_supported, reason = is_fp8_available() - - -def create_optimizer(model, use_pure_bf16, use_main_grad): - """Create optimizer""" - if use_main_grad: - assert use_pure_bf16 - model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.0001, - multi_precision=use_pure_bf16, - ) - if use_main_grad: - optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) - - return optimizer - - -class Net(paddle.nn.Layer): - """Network use for main_grad testing""" - - def __init__(self, fuse_wgrad_accumulation): - super().__init__() - self.layer = te.TransformerLayer( - 4096, - 16384, - 32, - layer_type="encoder", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ) - - def forward(self, inp): - out = self.layer(inp) - return out - - -def train(enable_master_grad, fuse_wgrad_accumulation=False): - """Train function""" - paddle.seed(10) - - accumulate_steps = 4 - - if fuse_wgrad_accumulation: - assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad" - - model = Net(fuse_wgrad_accumulation) - - optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad) - - loss_list = [] - for step_id in range(16): - inp = paddle.uniform([2, 1024, 4096], dtype="float32") - inp.stop_gradient = False - with te.fp8_autocast(enabled=True): - out = model(inp) - loss = out.mean() - loss_list.append(loss) - loss.backward() - - # gradient accumulation - 
if (step_id + 1) % accumulate_steps == 0: - optimizer.step() - optimizer.clear_grad() - - return loss_list - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -def test_master_grad(): - """Test main_grad""" - paddle.set_default_dtype("float32") - loss1 = train(enable_master_grad=False) - loss2 = train(enable_master_grad=True) - loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True) - - np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py deleted file mode 100644 index d9b1fa5cd1..0000000000 --- a/tests/paddle/test_operators.py +++ /dev/null @@ -1,1201 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE operators""" - -import struct - -import numpy as np -import paddle -import paddle.nn.functional as F -import pytest - -from utils import ( - assert_allclose, - create_fp8_meta, - get_fused_attention_backend, - is_fused_attention_supported, -) - -from transformer_engine import transformer_engine_paddle as tex -from transformer_engine.paddle.cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, - gemm, - fp8_gemm, - transpose, - cast_transpose, - cast_transpose_bgrad, - te_gelu, - gelu_fp8, - swiglu, - swiglu_fp8, - swiglu_pd, - dswiglu, - dgelu_cast_transpose_bgrad_fp8, - layernorm_fwd_fp8, - layernorm_fwd, - layernorm_bwd, - rmsnorm_fwd_fp8, - rmsnorm_fwd, - rmsnorm_bwd, - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, - fused_attn_fwd, - fused_attn_bwd, - scaled_softmax_forward, - scaled_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, - scaled_upper_triang_masked_softmax_backward, -) -from transformer_engine.paddle.fp8 import is_fp8_available -from 
transformer_engine.paddle.constants import FP8FwdTensors -from transformer_engine.common.recipe import DelayedScaling - -GEMM_CASES = [ - (256, 256, 512), - (32, 32, 32), - (16384, 1024, 2816), - (16384, 2816, 1024), - (16384, 1024, 1024), -] -is_fp8_supported, reason = is_fp8_available() - -SELF_ATTN_CASES = [(2, 512, 12, 64)] -CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)] -FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)] -ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16] - - -@pytest.fixture(autouse=True) -def setup(): - """Setup random seed before each test""" - np.random.seed(10) - paddle.seed(11) - yield - - -@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("inplace", [True, False]) -def test_quantize_dequantize(fp8_dtype, inplace): - """ - Test cast_to_fp8 and cast_from_fp8 - """ - a = paddle.rand(shape=(32, 32), dtype="float32") - # Init fp8_meta - fp8_meta = create_fp8_meta() - a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None - a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8) - b = cast_from_fp8( - a_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_OUTPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - assert_allclose(a, b, rtol=5e-2, atol=5e-2) - - -def copy_bits_from_float_to_uint16(f): - """ - Copy bits - """ - return struct.unpack("> 16 - - -def convert_float_to_uint16(float_list): - """ - convert float to uint16 - """ - new_output = [] - for x in np.nditer(float_list): - new_output.append(np.uint16(copy_bits_from_float_to_uint16(x))) - new_output = np.reshape(new_output, float_list.shape).view(np.uint16) - - return new_output - - -class TestTranspose: - """ - Test transpose operators - """ - - @staticmethod - def test_transpose_bf16(): - """ - Test BF16 transpose - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") - a_transposed = transpose(a, otype=tex.DType.kBFloat16) - assert_allclose(a_transposed, a.T) - 
- @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_transpose_fp8(fp8_dtype): - """ - Test FP8 transpose - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype) - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - @pytest.mark.parametrize("inplace", [True, False]) - def test_cast_transpose(fp8_dtype, inplace): - """ - Test cast_transpose - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - a_fp8_casted, a_fp8_transposed = None, None - if inplace: - a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8) - a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8) - a_fp8_casted, a_fp8_transposed = cast_transpose( - a, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - otype=fp8_dtype, - cast_out=a_fp8_casted, - transpose_out=a_fp8_transposed, - ) - - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - a_casted = cast_from_fp8( - a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(a_casted, a) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, 
tex.DType.kFloat8E5M2]) - def test_cast_transpose_bgrad(fp8_dtype): - """ - Test cast_transpose_bgrad - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad( - a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype - ) - - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - a_casted = cast_from_fp8( - a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(a_casted, a) - assert_allclose(a_transposed, a.T) - assert_allclose(bgrad, a.sum(axis=0)) - - -class TestActivation: - """ - Test activation operators - """ - - @staticmethod - def test_gelu_bf16(): - """ - Test BF16 GELU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - gelu_out = te_gelu(a, otype=tex.DType.kBFloat16) - gelu_ref = paddle.nn.GELU()(a) - - assert_allclose(gelu_out, gelu_ref, rtol=1e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_gelu_fp8(fp8_dtype): - """ - Test FP8 GELU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - fp8_meta = create_fp8_meta() - - gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - - gelu_out = cast_from_fp8( - gelu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - gelu_ref = paddle.nn.GELU()(a) - - assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_gelu_bwd_fp8(fp8_dtype): - """ - Test FP8 GELU Backward - """ - # y = 
GELU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - x.stop_gradient = False - y = paddle.nn.GELU()(x) - y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - paddle.autograd.backward([y], [y_grad], True) - # calculate fp8 - fp8_meta = create_fp8_meta() - x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8( - y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype - ) - - x_grad = cast_from_fp8( - x_grad_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - x_grad_t = cast_from_fp8( - x_grad_t_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) - assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01) - assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01) - - @staticmethod - def test_swiglu_bf16(): - """ - Test BF16 SwiGLU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - swiglu_out = swiglu(a, otype=tex.DType.kBFloat16) - swiglu_ref = swiglu_pd(a) - - assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_swiglu_fp8(fp8_dtype): - """ - Test FP8 SwiGLU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - fp8_meta = create_fp8_meta() - - swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - - swiglu_out = cast_from_fp8( - swiglu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - swiglu_ref = swiglu_pd(a) - - assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01) - - @staticmethod - def test_swiglu_bwd(): - """ - Test SwiGLU Backward - """ - # y = SwiGLU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - 
x.stop_gradient = False - y = swiglu_pd(x) - y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1 - paddle.autograd.backward([y], [y_grad], True) - # calculate fp8 - x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16) - - assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) - - -class TestGemm: - """ - Tests for gemm(cuBLASLt) operator - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" - ) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_bf16(m, n, k): - """ - Test "TN" BF16 GEMM - """ - a = paddle.rand(shape=(m, k), dtype="bfloat16") - b = paddle.rand(shape=(n, k), dtype="bfloat16") - - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = paddle.matmul(a, b.T) - # CublasLt inside tex.te_gemm assumes inputs are column major. - # Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the - # transpose of X. - # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T, - # which is equivalent to a@b^T = C in row major. 
- actual_out, _, _ = gemm( - b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False - ) - - assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5) - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" - ) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_bf16_inplace(m, n, k): - """ - Test "TN" BF16 GEMM, with accumulate=True - """ - min_val = -16 - max_val = 16 - a = paddle.rand(shape=(m, k), dtype="bfloat16") - b = paddle.rand(shape=(n, k), dtype="bfloat16") - c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16") - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = c + paddle.matmul(a, b.T) - - actual_out = paddle.clone(c) - _, _, _ = gemm( - b, - a, - paddle.bfloat16, - workspace, - False, - None, - False, - True, - "TN", - actual_out, - None, - False, - ) - - assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_fp8_randint(m, n, k): - """ - Test "TN" FP8 GEMM - """ - min_val = -4 - max_val = 4 - fp8_dtype = tex.DType.kFloat8E4M3 - out_dtype = paddle.float32 - fp8_meta = create_fp8_meta(num_gemms=1) - - a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32") - - a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32") - b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype) - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = paddle.matmul(a, b.T) - actual_out, _ = fp8_gemm( - b_casted, - fp8_meta.scale_inv, - FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, - a_casted, - fp8_meta.scale_inv, - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - out_dtype, - workspace, - ) - - assert_allclose(actual_out, 
ref_out) - - -class TestLayerNorm: - """ - Test layernorm operators - """ - - @staticmethod - def calc_fwd_ref(x, eps, gamma, beta): - """ - Calculate reference using paddle layer_norm op - """ - y = paddle.nn.functional.layer_norm( - x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps - ) - mean = paddle.mean(x, axis=-1) - var = paddle.var(x, axis=-1) - inv_var = paddle.sqrt(1.0 / var) - return y, mean, inv_var - - @staticmethod - def calc_bwd_ref(x, eps, gamma, beta, dy): - """ - Calculate reference using paddle layer_norm op - """ - x.stop_gradient = False - gamma.stop_gradient = False - beta.stop_gradient = False - - y = paddle.nn.functional.layer_norm( - x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps - ) - - paddle.autograd.backward([y], [dy], True) - - return x.grad, gamma.grad, beta.grad - - def test_layernorm_fwd(self): - """ - Test BF16 LayerNorm Forward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - beta = paddle.uniform(shape=(H,), dtype="bfloat16") - - y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) - - y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta) - - assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4) - assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3) - assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2) - - @staticmethod - def test_layernorm_fwd_fp8(): - """ - Test FP8 LayerNorm Forward - """ - fp8_dtype = tex.DType.kFloat8E4M3 - N, H = (16, 32) - eps = 1e-3 - - x = paddle.uniform(shape=(N, H), dtype="float32") - gamma = paddle.uniform(shape=(H,), dtype="float32") - beta = paddle.uniform(shape=(H,), dtype="float32") - - fp8_tensor = FP8FwdTensors.GEMM1_INPUT - fp8_meta = create_fp8_meta() - - y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32) - - y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, 
fp8_dtype) - - y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32) - - assert_allclose(y, y_ref, rtol=0.1, atol=0.01) - assert_allclose(mu, mu_ref) - assert_allclose(rsigma, rsigma_ref) - - def test_layernorm_bwd(self): - """ - Test BF16 LayerNorm Backward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - dy = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - beta = paddle.uniform(shape=(H,), dtype="bfloat16") - - dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy) - - _, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) - dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma) - - assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5) - assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5) - assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5) - - -class TestRMSNorm: - """ - Test rmsnorm operators - """ - - @staticmethod - def calc_fwd_ref(x, eps, gamma): - """ - Calculate rmsnorm reference using paddle op - """ - - norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps) - y = x * norm * gamma - - return y - - def calc_bwd_ref(self, x, eps, gamma, dy): - """ - Calculate rmsnorm bwd reference using paddle op - """ - x.stop_gradient = False - gamma.stop_gradient = False - - y = self.calc_fwd_ref(x, eps, gamma) - - paddle.autograd.backward([y], [dy], True) - - return x.grad, gamma.grad - - def test_rmsnorm_fwd(self): - """ - Test BF16 RMSNorm Forward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - - y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) - - y_ref = self.calc_fwd_ref(x, eps, gamma) - - assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2) - - @staticmethod - def test_rmsnorm_fwd_fp8(): - """ - Test FP8 RMSNorm Forward - """ - fp8_dtype = tex.DType.kFloat8E4M3 - N, H = 
(16, 32) - eps = 1e-3 - - x = paddle.uniform(shape=(N, H), dtype="float32") - gamma = paddle.uniform(shape=(H,), dtype="float32") - - fp8_tensor = FP8FwdTensors.GEMM1_INPUT - fp8_meta = create_fp8_meta() - - y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32) - - y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype) - - y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32) - - assert_allclose(y, y_ref, rtol=0.1, atol=0.01) - assert_allclose(rsigma, rsigma_ref) - - def test_rmsnorm_bwd(self): - """ - Test BF16 RMSNorm Backward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - dy = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - - dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy) - - _, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) - dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma) - - assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2) - assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2) - - -class TestFusedAttn: - """ - Test fused attention operators - """ - - def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False): - """ - set test input - """ - - def _random(shape): - if self.dtype == "bfloat16": - data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32") - return convert_float_to_uint16(data) - return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype) - - self.batch_size = b - self.q_seqlen = s_q - self.kv_seqlen = s_kv - self.num_heads = h - self.head_size = d - self.dropout_prob = 0.0 - self.scaling_factor = 1.0 / np.sqrt(d) - self.q_shape = (b, s_q, h, d) - self.kv_shape = (b, s_kv, h, d) - self.fuse_qkv_shape = (b, s_q, 3, h, d) - self.fuse_kv_shape = (b, s_kv, 2, h, d) - self.bias_shape = (1, h, s_q, s_kv) - self.attn_mode = attn_mode - self.dtype = dtype - self.is_causal_masking = 
is_causal_masking - - self.q = _random(self.q_shape) - if self.attn_mode == "self_attn": - assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen" - self.kv = self.q - else: - self.kv = _random(self.kv_shape) - - self.q_actual_seqlen = None - if self.is_causal_masking: - self.q_actual_seqlen = np.full( - self.batch_size, - self.q_seqlen, - dtype=np.int32, - ) - else: - self.q_actual_seqlen = np.random.randint( - low=20, - high=self.q_seqlen, - size=(self.batch_size,), - dtype=np.int32, - ) - self.kv_actual_seqlen = self.q_actual_seqlen - - self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen) - self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0) - self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen) - self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0) - self.attn_mask = np.ones( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), - dtype=np.int32, - ) - if self.is_causal_masking: - assert attn_mode == "self_attn", "only support causal masking for self attention" - for i in range(0, self.batch_size): - for j in range(self.q_actual_seqlen[i]): - self.attn_mask[i, :, j, : j + 1] = 0 - else: - for i in range(0, self.batch_size): - self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0 - - dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size)) - self.dout = paddle.to_tensor(dout, dtype=self.dtype) - - def _get_reference_out(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - - q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - - qk_out = paddle.matmul( - x=q_out * 
self.scaling_factor, - y=k_out, - transpose_x=False, - transpose_y=True, - ) - - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool") - attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype) - attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out) - attn_mask_out = paddle.cast(attn_mask_out, "float32") - softmax_out = F.softmax(attn_mask_out) - softmax_out = paddle.cast(softmax_out, self.dtype) - - if self.dropout_prob: - dropout_out = F.dropout( - softmax_out, - self.dropout_prob, - training=self.training, - mode="upscale_in_train", - ) - qkv_out = paddle.matmul(dropout_out, v_out) - else: - qkv_out = paddle.matmul(softmax_out, v_out) - - out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d] - - paddle.autograd.backward( - [out], - [self.dout], - retain_graph=True, - ) - return out, q_tensor.grad, k_tensor.grad, v_tensor.grad - - def _get_fused_attention_out(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - - if self.attn_mode == "self_attn": - qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d] - qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False) - else: - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d] - kv_tensor = paddle.to_tensor(kv, stop_gradient=False) - - q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) - kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - - qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd" - fused_attention_backend = get_fused_attention_backend( - num_heads=self.num_heads, - num_gqa_groups=self.num_heads, - q_seqlen=self.q_seqlen, - kv_seqlen=self.kv_seqlen, - head_size=self.head_size, - dtype=self.dtype, - dropout=self.dropout_prob, - qkv_layout=qkv_layout, - bias_type="no_bias", - mask_type="causal" if self.is_causal_masking else "padding", - ) - - qkv_dtype = 
tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 - out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None - if self.attn_mode == "self_attn": - out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked( - qkv_tensor, - q_cu_seqlen_tensor, - is_training=True, - max_seqlen=self.q_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - dqkv, _ = fused_attn_bwd_qkvpacked( - qkv_tensor, - q_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - max_seqlen=self.q_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - q_grad = dqkv[:, :, 0, :, :] - k_grad = dqkv[:, :, 1, :, :] - v_grad = dqkv[:, :, 2, :, :] - else: # attn_mode == 'cross_attn' - out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked( - q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - is_training=True, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - ) - dq, dkv, _ = fused_attn_bwd_kvpacked( - q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - ) - q_grad = dq - k_grad = dkv[:, :, 0, :, :] - v_grad = dkv[:, :, 1, :, :] - - return out, q_grad, k_grad, v_grad - - def 
_get_fused_attention_with_separate_qkv(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - - q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) - kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - - qkv_layout = "bshd_bshd_bshd" - fused_attention_backend = get_fused_attention_backend( - num_heads=self.num_heads, - num_gqa_groups=self.num_heads, - q_seqlen=self.q_seqlen, - kv_seqlen=self.kv_seqlen, - head_size=self.head_size, - dtype=self.dtype, - dropout=self.dropout_prob, - qkv_layout=qkv_layout, - bias_type="no_bias", - mask_type="causal" if self.is_causal_masking else "padding", - ) - - qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 - out, softmax_aux_tensor, rng_state = fused_attn_fwd( - q_tensor, - k_tensor, - v_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - is_training=True, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - dq, dk, dv, _ = fused_attn_bwd( - q_tensor, - k_tensor, - v_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - - return out, dq, dk, dv - - @pytest.mark.parametrize("b, s, h, d", 
SELF_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [True, False]) - def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): - """ - test self attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): - """ - test cross attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s_q, - kv_seqlen=s_kv, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bs2hd", - bias_type="no_bias", - mask_type="padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn") - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - 
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [True]) - def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): - """ - test flash attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [False, True]) - def test_fused_attn_with_separate_qkv_forward_backward( - self, b, s, h, d, dtype, is_causal_masking - ): - """ - test flash attention forward + backward with separate qkv inputs - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - 
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - -class TestSoftmax: - """ - Test softmax operators - """ - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_softmax_fwd_bwd(dtype): - """test scaled softmax""" - B, H, S = (16, 4, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - - y_ref = F.softmax(scale * x) - y = scaled_softmax_forward(x, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_masked_softmax_fwd_bwd(dtype): - """test scaled masked softmax""" - B, H, S = (16, 4, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S)) - mask_flipped = x[0, 0] <= 0.3 - mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4 - - y_ref = F.softmax(scale * x + mask_ref) - y = scaled_masked_softmax_forward(x, mask, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_masked_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def 
test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype): - """test scaled upper triang masked softmax""" - B, S = (16, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, S, S), dtype=dtype) - - mask = paddle.ones((S, S), dtype="int32") - col_beg, col_end = 1, S - for row in range(0, S): - mask[row, col_beg:col_end] = 0 - col_beg += 1 - - mask_ref = (mask.astype(dtype) - 1.0) * 1e4 - - y_ref = F.softmax(scale * x + mask_ref) - y = scaled_upper_triang_masked_softmax_forward(x, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) - - -@pytest.mark.parametrize("update_weight_scale_inv", [True, False]) -def test_amax_and_scale_update(update_weight_scale_inv): - """Test update_scale""" - num_gemm = 6 - history_len = 1024 - recipe = DelayedScaling() - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_max = recipe.fp8_format.value.max_fwd - non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2)) - - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") - rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) - rolled_history_ref[0] = 0.0 - amax_tensor = paddle.max(amax_history_tensor, axis=0) - scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32") - - def calc_ref(amax, scale, fp8_max, margin=0): - """Calculate reference scale""" - sf = (fp8_max / amax) / (2**margin) - sf = paddle.where(amax > 0.0, sf, scale) - sf = paddle.where(paddle.isfinite(amax), sf, scale) - return sf - - scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0) - if update_weight_scale_inv: - scale_inv_ref = 1.0 / scale_ref - else: - scale_inv_ref = paddle.zeros_like(scale_tensor) - scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref) - - # Placeholder - 
scale_actual = paddle.zeros_like(scale_tensor) - scale_inv_actual = paddle.zeros_like(scale_tensor) - - if update_weight_scale_inv: - non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace( - _amax_history=amax_history_tensor, - _scale=scale_actual, - _scale_inv=scale_inv_actual, - non_weight_mask=non_weight_mask, - fp8_dtype=int(fp8_dtype), - margin=0.0, - amax_compute="max", - ) - - assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7) - assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7) - assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7) - - -def test_update_latest_history(): - """Test update_latest_history""" - num_gemm = 6 - history_len = 1024 - - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") - amax = paddle.rand(shape=[num_gemm], dtype="float32") - - tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax) - - assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7) diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py deleted file mode 100644 index 82f970b2c8..0000000000 --- a/tests/paddle/test_parallel.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Test TE Paddle Parallel""" - -from pathlib import Path -import unittest - -from dist_launcher import TestDistributed -from utils import is_devices_enough - -from transformer_engine.paddle.fp8 import is_fp8_available - -test_root = Path(__file__).resolve().parent -gpu_has_fp8, reason = is_fp8_available() - - -class TestParallelLinear(TestDistributed): - """Test Linear in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_linear_tp(self): - """Tests linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py")) - - -class TestParallelLayerNormLinear(TestDistributed): - """Test LayerNormLinear in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_layernorm_linear_tp(self): - """Tests layernorm_linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py")) - - -class TestParallelLayerNormMLP(TestDistributed): - """Test LayerNormMLP in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_layernorm_mlp_tp(self): - """Tests layernorm_mlp with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py")) - - -class TestAmaxReduction(TestDistributed): - """Test amax reduction in dp mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_amax_reduction(self): - """Tests amax reduction""" - self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py")) - - -class TestPipelineParallel(TestDistributed): - """Test pipeline parallel""" - - @unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs") - 
@unittest.skipIf(not gpu_has_fp8, reason) - def test_pipeline_parallel(self): - """Tests pipeline parallel""" - self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py")) - - -class TestGroupSharding(TestDistributed): - """Test group sharding""" - - @unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_group_sharding(self): - """Tests group sharding""" - self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py")) - - -class TestParallelAttention(TestDistributed): - """Test MultiHeadAttention Layer in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelAttention needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_attention_tp(self): - """Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py")) - - -class TestParallelTransformerLayer(TestDistributed): - """Test Transformer Layer in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_transformer_tp(self): - """Tests Transformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py")) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/test_recompute.py b/tests/paddle/test_recompute.py deleted file mode 100644 index 59079b0d1d..0000000000 --- a/tests/paddle/test_recompute.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Test TE Paddle Recompute""" - -from pathlib import Path -import re -import subprocess - -import numpy as np -import pytest - -from transformer_engine.paddle.fp8 import is_fp8_available - -test_root = Path(__file__).resolve().parent -is_fp8_supported, reason = is_fp8_available() - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("use_reentrant", [False, True]) -def test_transformer_encoder_recompute(use_reentrant): - """ - Test TransformerLayer encoder recompute - """ - rtol = 1e-5 - atol = 1e-5 - - def launch_subprocess_and_check_output(enable_recompute): - """Launch training in subprocess and check output""" - try: - cmd = [ - "python", - str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"), - str(int(enable_recompute)), - str(int(use_reentrant)), - ] - result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) - - print(result) - - loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result) - memory_match = re.search(r"Peak memory:\s+(\d+)", result) - - loss_value = float(loss_match.group(1)) - memory_value = int(memory_match.group(1)) - - return loss_value, memory_value - - except subprocess.CalledProcessError as e: - raise ValueError(f"Subprocess failed with error: {e}") from e - - loss_recompute, peak_memory_recompute = launch_subprocess_and_check_output(True) - loss_ref, peak_memory_ref = launch_subprocess_and_check_output(False) - - assert peak_memory_recompute < peak_memory_ref - np.testing.assert_allclose(loss_recompute, loss_ref, rtol=rtol, atol=atol) diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py deleted file mode 100644 index b0a8d0d80b..0000000000 --- a/tests/paddle/utils.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Utils for testing""" - -import random -from typing import Union - -import numpy as np -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker - -import transformer_engine # pylint: disable=unused-import -from transformer_engine.paddle.constants import ( - TE_DType, - AttnBiasType, - AttnMaskType, - FusedAttnBackend, -) -from transformer_engine.paddle.fp8 import FP8TensorMeta -from transformer_engine import ( - transformer_engine_paddle as tex, -) # pylint: disable=wrong-import-order - - -def create_fp8_meta(num_gemms=1, amax_history_len=10): - """ - Create and initialize FP8TensorMeta - """ - fp8_meta = FP8TensorMeta(is_forward=True) - fp8_meta.prepare(num_gemms, amax_history_len) - return fp8_meta - - -def assert_allclose( - actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True -): - """Compare two input paddle tensors""" - if isinstance(actual, paddle.Tensor): - actual = paddle.cast(actual, "float32") - if isinstance(desired, paddle.Tensor): - desired = paddle.cast(desired, "float32") - if len(actual.shape) == 0: - actual = actual.item() - desired = desired.item() - else: - actual = actual.numpy() - desired = desired.numpy() - np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose) - - -def assert_shape(inp, expected_shape): - """Assert the shape of input tensor equals to expected shape""" - assert ( - inp.shape == expected_shape - ), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}" - - -def is_devices_enough(required): - """If the number of device is enough""" - return paddle.device.cuda.device_count() >= required - - -def set_random_seed(seed): - """Set random seed for reproducability.""" - fleet.meta_parallel.model_parallel_random_seed(seed) - - hcg = fleet.get_hybrid_communicate_group() - if paddle.distributed.get_world_size() > 1: - # obtain rank message of hybrid parallel - - mp_rank = 
hcg.get_model_parallel_rank() - mp_size = hcg.get_model_parallel_world_size() - - pp_rank = hcg.get_stage_id() - pp_size = hcg.get_pipe_parallel_world_size() - - dp_rank = hcg.get_data_parallel_rank() - dp_size = hcg.get_data_parallel_world_size() - - sharding_rank = hcg.get_sharding_parallel_rank() - else: - mp_rank, mp_size = 0, 1 - pp_rank, pp_size = 0, 1 - dp_rank, dp_size = 0, 1 - sharding_rank, _ = 0, 1 - - random.seed(seed + 100 * pp_rank) - np.random.seed(seed + 100 * pp_rank) - - seed_offset = seed + 1024 + paddle.distributed.get_world_size() - global_seed = ( - seed_offset - + pp_rank * (mp_size) - + dp_rank * (mp_size * pp_size) - + sharding_rank * (mp_size * pp_size * dp_size) - ) - - seed_offset += paddle.distributed.get_world_size() - local_seed = ( - seed_offset - + mp_rank - + pp_rank * (mp_size) - + dp_rank * (mp_size * pp_size) - + sharding_rank * (mp_size * pp_size * dp_size) - ) - - tracker = get_rng_state_tracker() - # tracker.reset() - if "global_seed" not in tracker.states_: - tracker.add("global_seed", global_seed) - if "local_seed" not in tracker.states_: - tracker.add("local_seed", local_seed) - - paddle.seed(global_seed) - - -def get_fused_attention_backend( - num_heads: int, - num_gqa_groups: int, - q_seqlen: int, - kv_seqlen: int, - head_size: int, - dtype: Union[paddle.dtype, str], - dropout: float, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - mask_type: str = "causal", -) -> tex.NVTE_Fused_Attn_Backend: - """Get cuDNN fused attention backend for attention config""" - if isinstance(dtype, str): - dtype = dict( - float32=paddle.float32, - bfloat16=paddle.bfloat16, - float16=paddle.float16, - )[dtype] - return tex.get_fused_attn_backend( - TE_DType[dtype], - TE_DType[dtype], - tex.get_nvte_qkv_layout(qkv_layout), - AttnBiasType[bias_type], - AttnMaskType[mask_type], - dropout, - num_heads, - num_gqa_groups, - q_seqlen, - kv_seqlen, - head_size, - ) - - -def is_fused_attention_supported( - num_heads: int, - num_gqa_groups: 
int, - q_seqlen: int, - kv_seqlen: int, - head_size: int, - dtype: Union[paddle.dtype, str], - dropout: float, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - mask_type: str = "causal", -) -> bool: - """Check if cuDNN fused attention is supported for attention config""" - backend = get_fused_attention_backend( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=dtype, - dropout=dropout, - qkv_layout=qkv_layout, - bias_type=bias_type, - mask_type=mask_type, - ) - return backend != FusedAttnBackend["No_Backend"] - - -def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None: - """Register allreduce hooks for sequence parallel tensors""" - - def is_sequence_parallel_parameter(parameter): - """If input tensor is marked as sequence parallel tensor""" - out = getattr(parameter, "sequence_parallel", False) - return out - - def create_allreduce_gradient_hook(param, accumulation_steps): - """Create allreduce gradient hook""" - hcg = fleet.get_hybrid_communicate_group() - pg = hcg.get_model_parallel_group().process_group - step = [0] - - @paddle.autograd.no_grad() - def __impl__(): - step[0] += 1 - if (step[0] % accumulation_steps) == 0: - if hasattr(param, "main_grad"): - pg.allreduce(param.main_grad).wait() - else: - pg.allreduce(param.grad).wait() - - return __impl__ - - if accumulation_steps <= 0 or not paddle.distributed.is_initialized(): - return - - hcg = fleet.get_hybrid_communicate_group() - mp_group = hcg.get_model_parallel_group() - if mp_group.nranks <= 1: - return - - params = [] - for p in model.parameters(): - if is_sequence_parallel_parameter(p): - params.append(p) - - for p in params: - hook = create_allreduce_gradient_hook(p, accumulation_steps) - p._register_backward_hook(hook) diff --git a/tests/pytorch/custom_ort_ops/.gitignore b/tests/pytorch/custom_ort_ops/.gitignore deleted file mode 100644 index d491fb774c..0000000000 --- 
a/tests/pytorch/custom_ort_ops/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -build -onnxruntime -libcustom_ort_ops.so diff --git a/tests/pytorch/custom_ort_ops/CMakeLists.txt b/tests/pytorch/custom_ort_ops/CMakeLists.txt deleted file mode 100644 index d3e95bd4bc..0000000000 --- a/tests/pytorch/custom_ort_ops/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -cmake_minimum_required(VERSION 3.21) -project(custom_ort_ops LANGUAGES CXX) - -# Dependencies -find_package(CUDAToolkit REQUIRED) -set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include) -if(NOT EXISTS "${ONNX_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find ONNX Runtime headers. " - "Please clone https://github.com/microsoft/onnxruntime " - "into TransformerEngine/tests/pytorch/onnx.") -endif() -include_directories(${ONNX_INCLUDE_DIR}) - -# Configure library -add_library(custom_ort_ops SHARED custom_op_library.cc) -target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart) -target_include_directories(custom_ort_ops PUBLIC - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(custom_ort_ops PRIVATE - ${ONNX_INCLUDE_DIR}/onnxruntime - ${ONNX_INCLUDE_DIR}/onnxruntime/core/session) - -# Install library -install(TARGETS custom_ort_ops DESTINATION .) diff --git a/tests/pytorch/custom_ort_ops/README.md b/tests/pytorch/custom_ort_ops/README.md deleted file mode 100644 index ca392805be..0000000000 --- a/tests/pytorch/custom_ort_ops/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Custom ONNX Runtime operators for Transformer Engine tests - -This directory contains code that builds custom ONNX operators for use -in Transformer Engine tests. It includes basic, non-performant -implementations of the FP8 quantization and dequantization operators -that are used when exporting Transformer Engine models to ONNX. 
- -For more information, see [the ONNX Runtime reference for custom -operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html). -Much of the code has been adapted from [an ONNX Runtime -test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc). - -## Usage - -* Build the custom operators: -```bash -$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh -``` -* Run the ONNX export tests with pytest: -```bash -$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py -``` \ No newline at end of file diff --git a/tests/pytorch/custom_ort_ops/build.sh b/tests/pytorch/custom_ort_ops/build.sh deleted file mode 100644 index 01502ba6fb..0000000000 --- a/tests/pytorch/custom_ort_ops/build.sh +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -ex - -: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))} -cd ${CUSTOM_ORT_OPS_PATH} - -# Download ONNX Runtime source -git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true - -# Configure and build with CMake -mkdir -p build -cmake -S . -B build -DCMAKE_INSTALL_PREFIX=. -cmake --build build --verbose -cmake --install build --verbose diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.cc b/tests/pytorch/custom_ort_ops/custom_op_library.cc deleted file mode 100755 index c7b94ff700..0000000000 --- a/tests/pytorch/custom_ort_ops/custom_op_library.cc +++ /dev/null @@ -1,102 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "custom_op_library.h" - -#define ORT_API_MANUAL_INIT -#include "onnxruntime_c_api.h" -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/session/onnxruntime_lite_custom_op.h" -#include - -namespace { - -template -void Quantize(OrtKernelContext* context, - const Ort::Custom::Tensor& input, - const Ort::Custom::Tensor& scale_inv, - Ort::Custom::Tensor& output) { - auto raw_input = input.Data(); - auto raw_scale_inv = scale_inv.Data(); - auto raw_output = reinterpret_cast(output.Allocate(input.Shape())); - const auto rs = static_cast(raw_scale_inv[0]); - const size_t N = input.NumberOfElement(); - for (size_t i = 0; i < N; ++i) { - const auto x = static_cast(raw_input[i]); - raw_output[i] = static_cast(x / rs); - } -} - -template -void Dequantize(OrtKernelContext* context, - const Ort::Custom::Tensor& input, - const Ort::Custom::Tensor& scale_inv, - Ort::Custom::Tensor& output) { - auto raw_input = reinterpret_cast(input.Data()); - auto raw_scale_inv = scale_inv.Data(); - auto raw_output = output.Allocate(input.Shape()); - const auto rs = static_cast(raw_scale_inv[0]); - const size_t N = input.NumberOfElement(); - for (size_t i = 0; i < N; ++i) { - const auto x = rs * static_cast(raw_input[i]); - raw_output[i] = static_cast(x); - } -} - -static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { - static std::vector ort_custom_op_domain_container; - static std::mutex ort_custom_op_domain_mutex; - std::lock_guard lock(ort_custom_op_domain_mutex); - ort_custom_op_domain_container.push_back(std::move(domain)); -} - -} // namespace - -OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { - Ort::Global::api_ = api->GetApi(ORT_API_VERSION); - - // Namespace for custom ops - static const char* c_OpDomain = "trt"; - - // 
Construct custom ops - static const std::unique_ptr c_Quantize{ - Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear", - "CPUExecutionProvider", - Quantize) - }; - static const std::unique_ptr c_Dequantize{ - Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear", - "CPUExecutionProvider", - Dequantize<__nv_fp8_e4m3, float, float>) - }; - - // Register custom ops - OrtStatus* result = nullptr; - ORT_TRY { - Ort::CustomOpDomain domain{c_OpDomain}; - domain.Add(c_Quantize.get()); - domain.Add(c_Dequantize.get()); - Ort::UnownedSessionOptions session_options(options); - session_options.Add(domain); - AddOrtCustomOpDomainToContainer(std::move(domain)); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - Ort::Status status{e}; - result = status.release(); - }); - } - return result; -} diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 4f170e3f84..9e11e07e11 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -19,8 +19,8 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.common.recipe import Format -from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) @@ -47,14 +47,14 @@ def _mapped_argtype(opt, typemap): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=512, help="Input sequence length.") + 
parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." ) - parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) @@ -288,33 +288,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else tex.CommOverlapHelper(bootstrap_pg) ) - if opts.comm_type == tex.CommOverlapType.RS: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS - elif opts.p2p: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - ) - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ) - elif opts.comm_type == tex.CommOverlapType.AG: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - ) - else: - raise TypeError("Invalid comm+GEMM overlap type!") - # Initialize userbuffers with (M, N) buffer # M = sequence * batch # N = hidden size @@ -322,11 +295,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) buffer_dtype = torch.bfloat16 - if ( - opts.fp8 - and not opts.bulk_overlap - and (opts.comm_type == 
tex.CommOverlapType.AG or opts.fp8_output) - ): + if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG: buffer_dtype = torch.uint8 ub_obj = ( tex.CommOverlapP2P( @@ -421,6 +390,10 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) + # Allocate cuBLAS workspace + workspace_size = 3 * get_cublas_workspace_size_bytes() + workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda") + # Gather global tensors and calculate reference result (need these first for Fp8 scales) if opts.bulk_overlap: ker_g = torch.transpose(kernel_t, 0, 1) @@ -467,120 +440,123 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) + inp_quantizer = None + ker_quantizer = None + out_quantizer = None + bulk_inp_quantizer = None + inp2_quantizer = None + ker2_quantizer = None + out2_quantizer = None if opts.fp8: - fp8_formats = { - tex.DType.kFloat8E4M3: Format.E4M3, - tex.DType.kFloat8E5M2: Format.E5M2, - } - # Structure to maintain amax and scale/scale_inv information for the kernel and input - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_meta = tex.FP8TensorMeta() num_gemms = 6 if ub_obj2 is not None else 3 - fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda") - fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda") - fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_dtype = tex.DType.kFloat8E4M3 + fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda") # Compute initial amaxes and scales inp_amax = torch.max(torch.abs(inp_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_INPUT].copy_(inp_amax) + fp8_amaxes[0].copy_(inp_amax) ker_amax = torch.max(torch.abs(ker_g)) - 
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) + fp8_amaxes[1].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) + fp8_amaxes[2].copy_(ref_amax) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) + fp8_amaxes[5].copy_(bulk_amax) elif ub_obj2 is not None: inp2_amax = torch.max(torch.abs(inp2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) + fp8_amaxes[3].copy_(inp2_amax) ker2_amax = torch.max(torch.abs(ker2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_WEIGHT].copy_(ker2_amax) + fp8_amaxes[4].copy_(ker2_amax) ref2_amax = torch.max(torch.abs(ref2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(ref2_amax) - fp8_meta.scale = _default_sf_compute( - fp8_meta.amax_history[1], fp8_meta.scale, fp8_formats[fp8_dtype].value.max_fwd, 1 - ) - fp8_meta.scale_inv = torch.reciprocal(fp8_meta.scale) + fp8_amaxes[5].copy_(ref2_amax) - # Cast input to Float8Tensor - inp_fp8 = tex.cast_to_fp8(inp, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype) + inp_quantizer = Float8Quantizer(fp8_scales[0].clone(), fp8_amaxes[0].clone(), fp8_dtype) + ker_quantizer = Float8Quantizer(fp8_scales[1].clone(), fp8_amaxes[1].clone(), fp8_dtype) + if opts.fp8_output: + out_quantizer = Float8Quantizer(fp8_scales[2].clone(), fp8_amaxes[2].clone(), fp8_dtype) - # Cast kernel to Float8Tensor - kernel_t_fp8 = tex.cast_to_fp8( - kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype - ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: - bulk_inp_fp8 = tex.cast_to_fp8( - bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype + bulk_inp_quantizer = Float8Quantizer( + fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype ) elif ub_obj2 is not None: - kernel2_t_fp8 = tex.cast_to_fp8( - 
kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype + inp2_quantizer = Float8Quantizer( + fp8_scales[3].clone(), fp8_amaxes[3].clone(), fp8_dtype + ) + ker2_quantizer = Float8Quantizer( + fp8_scales[4].clone(), fp8_amaxes[4].clone(), fp8_dtype ) + if opts.fp8_output: + out2_quantizer = Float8Quantizer( + fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype + ) + + # Cast input to Float8Tensor + inp_fp8 = inp_quantizer(inp) + + # Cast kernel to Float8Tensor + kernel_t_fp8 = ker_quantizer(kernel_t) + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: + bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp) + elif ub_obj2 is not None: + kernel2_t_fp8 = ker2_quantizer(kernel2_t) # Make sure the inputs are cast correctly if opts.check_numerics: torch.allclose( inp.to(dtype=torch.float32), - inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) torch.allclose( kernel_t.to(dtype=torch.float32), - kernel_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + kernel_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), - bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], + bulk_inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) elif ub_obj2 is not None: torch.allclose( kernel2_t.to(dtype=torch.float32), - kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], + kernel2_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) - # Set Fp8 scales for userbuffers - if opts.comm_type == tex.CommOverlapType.AG: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) - if ub_obj2 is not None: - ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - elif opts.bulk_overlap: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) 
- else: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) - # Set up comm/compute buffers - ubuf_out2 = None + rs_out = None rs_out2 = None if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp, 1) + ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True) gemm_inp = inp else: - ub_obj.copy_input_to_ubuf(inp_fp8 if opts.fp8 else inp, 1) - gemm_inp = ub_obj.get_ubuf_output(1) - ubuf_out = None - rs_out = None + ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True) + gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size()) if ub_obj2 is not None: - ubuf_out2 = ub_obj2.get_ubuf_output(1) + if opts.fp8 and opts.fp8_output: + ub_obj2.set_buffer_params(out_quantizer) rs_out2 = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" ) else: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0) - ubuf_out = None - else: - ubuf_out = ub_obj.get_ubuf_output(1) + ub_obj.copy_into_buffer( + bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False + ) + if opts.fp8: + ub_obj.set_buffer_params(bulk_inp_quantizer) + elif opts.fp8 and opts.fp8_output: + ub_obj.set_buffer_params(out_quantizer) gemm_inp = inp_fp8 if opts.fp8 else inp rs_out = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" @@ -588,88 +564,47 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use def _fp8_gemm(): - return tex.fp8_gemm( + return tex.general_gemm( kernel_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, gemm_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + 
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=ub_algo, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap, ) def _fp8_gemm2(gemm1_out): gemm2_inp = tex.gelu( - ( - tex.cast_from_fp8( - gemm1_out, - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) - if opts.fp8_output - else gemm1_out - ), - fp8_meta, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, + (gemm1_out.dequantize() if opts.fp8_output else gemm1_out), + inp2_quantizer, ) - return tex.fp8_gemm( + return tex.general_gemm( kernel2_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype, gemm2_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out2_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic_rs_p2p - else tex.CommOverlapAlgo.ATOMIC_GEMM_RS - ), ub=ub_obj2, - extra_output_tensor=rs_out2, - out=ubuf_out2, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None, + ub_type=tex.CommOverlapType.AG, + extra_output=rs_out2, ) def _gemm(): - return tex.gemm( + return tex.general_gemm( kernel_t, gemm_inp, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - ub_algo=ub_algo, + 
workspace, + out_dtype=torch.bfloat16, + use_split_accumulator=te.module.base._2X_ACC_FPROP, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap, ) # Trigger GEMM @@ -746,10 +681,10 @@ def _gemm(): output_info = "" if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap AG output is already gathered - test_out = ub_obj.get_ubuf_output(1) + test_out = ub_obj.get_buffer(bulk_inp_quantizer, False) else: # Bulk overlap RS output needs to be gathered - out_local = ub_obj.get_ubuf_output(0) + out_local = ub_obj.get_buffer(bulk_inp_quantizer, True) output_info += f"rs_output: {list(out_local.shape)} | " test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0] @@ -775,17 +710,7 @@ def _gemm(): test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] else: # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) - output = ( - tex.cast_from_fp8( - all_outputs[0], - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) - if opts.fp8_output - else all_outputs[0] - ) + output = all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0] test_out = torch.transpose( te.distributed.gather_along_first_dim( torch.transpose(output, 0, 1), tp_group @@ -798,25 +723,6 @@ def _gemm(): output = rs_out.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] - if opts.fp8: - dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n" - + f"scale = {fp8_meta.scale[:3].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) - if ub_obj2 is not None: - dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - 
f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n" - + f"scale = {fp8_meta.scale[3:].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) - ref_out = ref2_g if ub_obj2 is not None else ref_g test_nonzeros = torch.count_nonzero(test_out) ref_nonzeros = torch.count_nonzero(ref_out) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 5a67bd616a..d4a01386ee 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -9,6 +9,7 @@ import socket import argparse import warnings +import pprint import torch import torch.distributed as dist @@ -39,6 +40,8 @@ def _te_layer_argtype(name): def _get_layer_args(config, tp_group, tp_size, reference=False): hidden_size = config.num_heads * config.head_dim + ffn_hidden_size = 4 * hidden_size + qkv_size = 3 * hidden_size input_shape = [config.seq_length, config.batch_size, hidden_size] args = [hidden_size] kwargs = { @@ -47,46 +50,41 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): "tp_group": tp_group, "tp_size": tp_size, "sequence_parallel": True, + "ub_overlap_ag": not reference, + "ub_overlap_rs": not reference, } - kwargs["ub_overlap_ag"] = not reference - if config.layer_type is te.Linear: + if config.layer_type in [te.Linear, te.LayerNormLinear]: if config.linear_parallel_mode == "row": - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["ub_overlap_rs"] = not reference + input_shape[-1] = ffn_hidden_size // tp_size + args = [ffn_hidden_size, hidden_size] + kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2" elif config.linear_parallel_mode == "column": input_shape[0] = config.seq_length // tp_size - args.append(3 * hidden_size) - kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not 
reference + args.append(qkv_size) + kwargs["ub_name"] = "qkv" + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference kwargs["parallel_mode"] = config.linear_parallel_mode - kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv" else: input_shape[0] = config.seq_length // tp_size + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(ffn_hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not reference + kwargs["hidden_dropout"] = 0.0 + kwargs["set_parallel_mode"] = True kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference - kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference - if config.layer_type is te.LayerNormLinear: - args.append(3 * hidden_size) - kwargs["parallel_mode"] = "column" - kwargs["ub_name"] = "qkv" - else: - kwargs["set_parallel_mode"] = True - kwargs["ub_overlap_rs"] = not reference - if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: - args.append(4 * hidden_size) - kwargs["seq_length"] = config.seq_length - if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: - args.append(config.num_heads) - kwargs["attention_dropout"] = 0.0 - kwargs["fuse_qkv_params"] = True - if config.layer_type is te.MultiheadAttention: - kwargs["input_layernorm"] = True - else: - kwargs["ub_tp_comm_overlap"] = not reference - kwargs["hidden_dropout"] = 0.0 + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not 
reference return args, kwargs, input_shape @@ -97,12 +95,12 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP) parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( @@ -144,7 +142,7 @@ def _parse_args(argv=None, namespace=None): "--overlap-rs-dgrad", action="store_true", default=False, - help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.", + help="Replace bulk DGRAD/WGRAD overlaps with DGRAD+RS in the backward pass for AG+GEMM.", ) parser.add_argument( "--debug", @@ -175,7 +173,7 @@ def _compare_tensors(name, test, ref, rtol, atol): ) return 1, numerics_info - diff = torch.abs(test - ref).flatten() + diff = torch.abs(test.flatten() - ref.flatten()) m = torch.argmax(diff) abs_err = diff[m].item() rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5) @@ -254,8 +252,8 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ub_cfgs = None if opts.overlap_rs_dgrad: ub_cfgs = { - "proj_dgrad": {"method": "ring_exchange"}, "qkv_dgrad": {"method": "ring_exchange"}, + "fc1_dgrad": {"method": "ring_exchange"}, } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], @@ -271,6 +269,10 @@ def 
dist_print(msg, src=None, end="\n", debug=False, error=False): with te.fp8_model_init(enabled=opts.fp8_init): test_model = opts.layer_type(*args, **kwargs) dist_print("Initialized test model...", debug=True) + if WORLD_RANK == 0: + pprint.pprint(kwargs) + sys.stdout.write("\n") + dist.barrier() # Initialize the reference model and copy all parameters ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) @@ -305,8 +307,8 @@ def run_fwd_bwd(model, x): out, *_ = y else: out = y - loss = out.sum() - loss.backward() + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() @@ -342,29 +344,27 @@ def run_fwd_bwd(model, x): ref_grads.append(ref_param.grad) # Make sure we have the same number of gradients - num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") if len(test_grads) != len(ref_grads): - num_grads_failed[0] = 1 + numerics_failed[0] = 1 numerics_info = ( "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + f"expected {len(ref_grads)} but got {len(test_grads)}." 
) dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) # Now validate accuracy - numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda") - if not bool(num_grads_failed.item()): + if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): rtol = 0.125 if opts.fp8 else 0.025 atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[i] = int(grad_failed) - return_code = torch.max(numerics_failed) - dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world) - else: - return_code = num_grads_failed + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + if bool(numerics_failed.item()) and not opts.debug: + break te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) @@ -374,7 +374,7 @@ def run_fwd_bwd(model, x): if opts.debug and WORLD_RANK == 0: print("Exiting...\n", end="", flush=True) - return return_code.item() + return numerics_failed[0].item() if __name__ == "__main__": diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 64f36051c6..2d301e3151 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -4,9 +4,10 @@ # # See LICENSE for license information. 
-import sys -import os import argparse +import datetime +import os +import sys from functools import wraps import transformer_engine.pytorch as te @@ -14,7 +15,12 @@ from torch import nn import torch.distributed as dist -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, + DelayedScaling, + Format, + Recipe, +) from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -23,15 +29,27 @@ WORLD_RANK, WORLD_SIZE = None, None NCCL_WORLD = None LOSS_FN = nn.MSELoss() -FP8 = False +QUANTIZATION = None + + +# Disable TF32 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False -# Fp8 recipe setup -fp8_format = Format.HYBRID -fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + +# Quantization recipe setup +def quantization_recipe() -> Recipe: + if QUANTIZATION == "fp8": + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + if QUANTIZATION == "mxfp8": + return MXFP8BlockScaling() + return te.fp8.get_default_fp8_recipe() def main(argv=None, namespace=None): - global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8 + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) @@ -44,6 +62,7 @@ def main(argv=None, namespace=None): "backend": "nccl", "rank": WORLD_RANK, "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), } dist_init_kwargs["init_method"] = "env://" dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") @@ -57,9 +76,17 @@ def main(argv=None, namespace=None): parser = argparse.ArgumentParser() parser.add_argument("-l", "--layer-type", type=str) - parser.add_argument("--fp8", action="store_true", default=False) + parser.add_argument("--quantization", type=str, default=None) args = parser.parse_args(argv, namespace) + # 
Quantization scheme + QUANTIZATION = args.quantization + if QUANTIZATION in ("fp8", "mxfp8"): + global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE + SEQ_LEN = 32 + BATCH_SIZE = 32 + HIDDEN_SIZE = 128 + test_dict = [ test_linear, test_layernorm, @@ -68,8 +95,6 @@ def main(argv=None, namespace=None): test_transformer_layer, ] - FP8 = args.fp8 - for test in test_dict: test() dist.destroy_process_group() @@ -124,11 +149,10 @@ def dist_print(msg, src=None, end="\n", error=False): stream = sys.stderr if error else sys.stdout if WORLD_RANK == (0 if src is None else src): stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") - dist.barrier() def _get_tolerances(dtype): - if FP8: + if QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} if dtype == torch.float16: @@ -153,8 +177,7 @@ def _check_outputs(output_single_node, output_distributed): dist_print(output_info, src=WORLD_RANK, error=output_failed) numerics_failed[0] = int(output_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) - if bool(numerics_failed.item()): - sys.exit(1) + assert not bool(numerics_failed.item()) def _match_param_sizes(dist_param, single_param): @@ -213,13 +236,12 @@ def _check_gradients(model_distributed, model_single, main_grad_check=False): ) if grad_failed: - dist_print(i) - dist_print(name) + dist_print(i, src=WORLD_RANK) + dist_print(name, src=WORLD_RANK) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) - if bool(numerics_failed.item()): - sys.exit(1) + assert not bool(numerics_failed.item()) def _copy_params(model_distributed, model_single): @@ -243,9 +265,18 @@ def _apply_models( model_single_node, model_distributed, input_single_node, input_distributed, **kwargs ): _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True - with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe): + input_single_node.requires_grad_() + 
input_distributed.requires_grad_() + with te.fp8_autocast( + enabled=QUANTIZATION is not None, + fp8_recipe=quantization_recipe(), + ): output_single_node = model_single_node(input_single_node, **kwargs) - with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe, fp8_group=NCCL_WORLD): + with te.fp8_autocast( + enabled=QUANTIZATION is not None, + fp8_recipe=quantization_recipe(), + fp8_group=NCCL_WORLD, + ): output_distributed = model_distributed(input_distributed, **kwargs) return output_single_node, output_distributed @@ -544,9 +575,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg """ # Set parameter data type params_dtype = kwargs.get("params_dtype", torch.float32) - FFN_HIDDEN_SIZE = ( - 64 if FP8 else 32 - ) # larger tensors lead to numerical failures with thight atol and rtol + FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128 # Create models model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs) @@ -636,9 +665,7 @@ def test_layernorm_mlp(): @run_distributed_test() def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs): params_dtype = kwargs.get("params_dtype", torch.float32) - FFN_HIDDEN_SIZE = ( - 64 if FP8 else 32 - ) # larger tensors lead to numerical failures with thight atol and rtol + FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128 model_single_node = te.TransformerLayer( HIDDEN_SIZE, FFN_HIDDEN_SIZE, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index c285da7fbd..52420efca5 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -16,23 +16,22 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -RNG_SEED: int = 1234 -SEQ_LENGTH: int = 512 +RNG_SEED: int = 42 +SEQ_LENGTH: int = 1024 BATCH_SIZE: int = 2 -NUM_HEADS: int = 12 -HEAD_DIM: int = 64 - -# 
NOTE: te.Linear is intentionally omitted here and manually added later for testing both -# row and column parallel layouts. +NUM_HEADS: int = 16 +HEAD_DIM: int = 48 TE_LAYERS = [ + te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, te.TransformerLayer, ] +MAX_LAYER_NAME_LENGTH = max([len(layer.__name__) for layer in TE_LAYERS]) TEST_ROOT = Path(__file__).parent.resolve() -NUM_PROCS: int = min(torch.cuda.device_count(), 4) +NUM_PROCS: int = torch.cuda.device_count() LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] if tex.ubuf_built_with_mpi(): LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"] @@ -48,7 +47,7 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -64,19 +63,15 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg if bulk: test_cmd.append("--bulk-overlap") else: - if fp8_in: + if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_out: - test_cmd.append("--fp8-output") if p2p: test_cmd.append("--p2p") - if aggregate: - test_cmd.append("--aggregate") if atomic: - if torch.cuda.get_device_properties(0).major < 9: - pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).") test_cmd.append("--atomic") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) @@ -88,7 +83,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, 
linear_parallel_mode, overlap_rs_dgrad, fp8): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -99,15 +94,16 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] - if layer_type == te.Linear.__name__: + if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]: test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") + if overlap_rs_dgrad: + test_cmd.append("--overlap-rs-dgrad") + if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_init: - test_cmd.append("--fp8-init") os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" @@ -128,88 +124,39 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): @pytest.mark.parametrize( - "fp8,aggregate", - [ - (False, False), - (False, True), - (True, False), - (True, True), - ], - ids=[ - " BF16 IN - RING-EXCHANGE ", - " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ", - " FP8 IN - RING-EXCHANGE ", - " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ", - ], + "fp8", + (False, True), + ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "], ) -def test_split_all_gather_overlaps(fp8, aggregate): +def test_split_all_gather_overlaps(fp8): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. 
""" - _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate) + _run_gemm_with_overlap("AG", False, True, False, fp8) @pytest.mark.parametrize( - "fp8_in,fp8_out,p2p", + "fp8,p2p", [ - (False, False, False), - (False, False, True), - (True, False, False), - (True, False, True), - (True, True, False), - (True, True, True), + (False, False), + (False, True), + (True, False), + (True, True), ], ids=[ - " BF16 IN - BF16 OUT - PIPELINE ", - " BF16 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - BF16 OUT - PIPELINE ", - " FP8 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - FP8 OUT - PIPELINE ", - " FP8 IN - FP8 OUT - RING-EXCHANGE ", + " BF16 - PIPELINE ", + " BF16 - RING-EXCHANGE ", + " FP8 - PIPELINE ", + " FP8 - RING-EXCHANGE ", ], ) -def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p): +def test_split_reduce_scatter_overlaps(fp8, p2p): """ Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False) - - -@pytest.mark.parametrize( - "ag_type,rs_type,p2p,fp8_out", - [ - (0, 0, False, False), - (0, 1, False, False), - (0, 1, False, True), - (0, 2, False, False), - (0, 2, False, True), - (0, 0, True, False), - (0, 0, True, True), - (1, 0, True, False), - (1, 0, True, True), - ], - ids=[ - " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", - ], -) -def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): - """ - Test paired 
(all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with - direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. - """ - os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type) - os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type) - _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False) + _run_gemm_with_overlap("RS", False, p2p, False, fp8) @pytest.mark.parametrize( @@ -223,12 +170,12 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): ("RS", True, 8), ], ids=[ - "ALL-GATHER - BF16 - 1 connections", + "ALL-GATHER - BF16 - 1 connections", "REDUCE-SCATTER - BF16 - 1 connections", - "REDUCE-SCATTER - FP8 - 1 connections", - "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", "REDUCE-SCATTER - BF16 - 8 connections", - "REDUCE-SCATTER - FP8 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], ) def test_bulk_overlaps(comm_type, fp8, connections): @@ -242,38 +189,48 @@ def test_bulk_overlaps(comm_type, fp8, connections): " 9.0 (HOPPER ARCH)." 
) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_gemm_with_overlap(comm_type, True, False, False, fp8) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" else: - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_gemm_with_overlap(comm_type, True, False, False, fp8) +@pytest.mark.parametrize("fp8", (False, True), ids=[" BF16 ", " FP8 "]) @pytest.mark.parametrize( - "layer_type,linear_parallel_mode", - ( - [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")] - + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) - ), - ids=( - [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "] - + [(" " + layer.__name__ + " ") for layer in TE_LAYERS] - ), -) -@pytest.mark.parametrize( - "fp8,fp8_init", + "layer_type,linear_parallel_mode,overlap_rs_dgrad", [ - (False, False), - (True, False), - (True, True), - ], + (te.Linear.__name__, "row", False), + (te.Linear.__name__, "column", False), + (te.Linear.__name__, "column", True), + (te.LayerNormLinear.__name__, "row", False), + (te.LayerNormLinear.__name__, "column", False), + (te.LayerNormLinear.__name__, "column", True), + ] + + list( + zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + [None] * len(TE_LAYERS[2:]) * 2, + [False, True] * len(TE_LAYERS[2:]), + ) + ), ids=[ - " BF16 GEMM - BF16 PARAMS ", - " FP8 GEMM - BF16 PARAMS ", - " FP8 GEMM - FP8 PARAMS ", + f" {te.Linear.__name__} - ROW-PARALLEL ", + f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", + f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", + ] + + [ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [layer.__name__ for layer in 
TE_LAYERS[2:] for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]), + ) ], ) -def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): +def test_layers_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 598859b826..c8ef7687fa 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -5,27 +5,38 @@ from __future__ import annotations import argparse +from collections.abc import Iterable import functools import itertools import os import pathlib import subprocess import sys +from typing import Optional import pytest import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex -# Check if FP8 is supported + +# Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +quantization_list: list[Optional[str]] = [None] +if fp8_available: + quantization_list.append("fp8") +if mxfp8_available: + quantization_list.append("mxfp8") @functools.cache 
@@ -66,22 +77,18 @@ def make_reference_and_test_tensors( in Transformer Engine operations. """ - - # Random data ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - - # Make copy of tensor + test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(ref) - else: - test = ref.to(device=test_device, dtype=test_dtype) - if test.data_ptr() == ref.data_ptr(): - test = test.clone() - - # Make sure reference and test tensors represent exact same values + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + test = quantizer(test) + elif test.data_ptr() == ref.data_ptr(): + test = test.clone() ref.copy_(test) - - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test @@ -120,6 +127,21 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: raise ValueError(f"Unsupported dtype ({dtype})") +def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name == "fp8": + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def _test_all_reduce( *, local_size: int = 17, @@ -293,17 +315,16 @@ def _test_reduce_scatter( def _test_basic_linear( *, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + local_batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - 
fp8_grad_output: bool = False, + quantization: Optional[str] = None, + quantized_weight: bool = False, tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + quantized_compute = quantization is not None # Distributed process group process_group = world_group() @@ -313,10 +334,13 @@ def _test_basic_linear( # Tensor dimensions local_out_features, local_in_features = local_weight_shape out_features, in_features = local_out_features, local_in_features + batch_size = local_batch_size if tensor_parallel_mode == "column": out_features *= world_size elif tensor_parallel_mode == "row": in_features *= world_size + if sequence_parallel: + batch_size *= world_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -326,21 +350,28 @@ def _test_basic_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) + if isinstance(w_test, QuantizedTensor): + w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=quantized_compute, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -391,7 +422,8 @@ def _test_basic_linear( x_test.requires_grad_() # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, 
out_features, @@ -404,7 +436,7 @@ def _test_basic_linear( with torch.no_grad(): op.weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -412,10 +444,8 @@ def _test_basic_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -429,17 +459,16 @@ def _test_basic_linear( def _test_linear( *, bias: bool = True, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + local_batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_grad_output: bool = False, + quantization: Optional[str] = None, + quantized_weight: bool = False, tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + quantized_compute = quantization is not None # Distributed process group process_group = world_group() @@ -449,10 +478,13 @@ def _test_linear( # Tensor dimensions local_out_features, local_in_features = local_weight_shape out_features, in_features = local_out_features, local_in_features + batch_size = local_batch_size if tensor_parallel_mode == "column": out_features *= world_size elif tensor_parallel_mode == "row": in_features *= world_size + if sequence_parallel: + batch_size *= world_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -462,14 +494,19 @@ def _test_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + 
if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) + if isinstance(w_test, QuantizedTensor): + w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -485,9 +522,11 @@ def _test_linear( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=quantized_compute, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -552,7 +591,8 @@ def _test_linear( x_test.requires_grad_() # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -571,7 +611,7 @@ def _test_linear( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -579,12 +619,8 @@ def _test_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -603,8 +639,8 @@ def _test_fp8_scale_update( amax_history_len: int = 31, amax_compute_algo: str = "max", margin: float = 2, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + 
local_weight_shape: tuple[int, int] = (32, 32), + batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", tensor_parallel_mode: str = "column", @@ -715,20 +751,12 @@ def ref_amax_and_scale( y_test.backward(dy_test) # Check results - forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - x_fp8_meta = op.get_fp8_meta("input")[forward_key] - w_fp8_meta = op.get_fp8_meta("param")[forward_key] - dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] - x_amax_test = x_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - w_amax_test = w_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - dy_amax_test = dy_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - x_scale_test = x_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - w_scale_test = w_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - dy_scale_test = dy_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - torch.testing.assert_close(x_amax_test, x_amax_ref) - torch.testing.assert_close(w_amax_test, w_amax_ref) - torch.testing.assert_close(dy_amax_test, dy_amax_ref) + x_quantizer = op.get_quantizer("forward", 0) + w_quantizer = op.get_quantizer("forward", 1) + dy_quantizer = op.get_quantizer("backward", 0) + x_scale_test = x_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) + w_scale_test = w_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) + dy_scale_test = dy_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) torch.testing.assert_close(x_scale_test, x_scale_ref) torch.testing.assert_close(w_scale_test, w_scale_ref) torch.testing.assert_close(dy_scale_test, dy_scale_ref) @@ -755,38 +783,32 @@ def run_parallel_tests() -> None: # Basic linear op for config in itertools.product( - (False, True) if fp8_available else (False,), + quantization_list, ("column", "row"), (False, True), 
): if rank == 0: print(f"Running _test_basic_linear with {config=}") - fp8, tensor_parallel_mode, sequence_parallel = config + quantization, tensor_parallel_mode, sequence_parallel = config _test_basic_linear( - fp8_compute=fp8, - fp8_input=fp8, - fp8_weight=fp8, - fp8_grad_output=fp8, + quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, sequence_parallel=sequence_parallel, ) # Linear op for config in itertools.product( - (False, True) if fp8_available else (False,), + quantization_list, ("column", "row"), ): if rank == 0: print(f"Running _test_linear with {config=}") - fp8, tensor_parallel_mode = config + quantization, tensor_parallel_mode = config dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, - fp8_compute=fp8, - fp8_input=fp8, - fp8_weight=fp8, - fp8_grad_output=fp8, + quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, ) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1a6191f06c..7be9cd01ae 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -27,29 +27,31 @@ pytest.skip("Distributed training needs at least 2 GPUs.") fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] -def _run_test(fp8): +def _run_test(quantization): test_path = TEST_ROOT / "run_numerics.py" test_cmd = LAUNCH_CMD + [str(test_path)] - if fp8: - test_cmd += ["--fp8"] + if quantization is not None: + test_cmd += ["--quantization", quantization] - result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) - if result.returncode != 0 or "NUMERICAL CHECK FAILED" in 
result.stderr.decode(): - raise AssertionError(result.stderr.decode()) + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 all_boolean = [True, False] -@pytest.mark.parametrize("fp8", all_boolean) -def test_distributed(fp8): - if fp8 and not fp8_available: +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8"]) +def test_distributed(quantization): + if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) - _run_test(fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + _run_test(quantization) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 02a85f0ac4..4298d17c9c 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -12,7 +12,7 @@ def get_torch_version(): - """Get pytorch version from __version__""" + """Get PyTorch version from __version__""" def get_torch_version_str(): import torch @@ -22,25 +22,14 @@ def get_torch_version_str(): return PkgVersion(get_torch_version_str()) -if torch.cuda.device_count() < 4: - pytest.skip("FSDP2 test requires at least 4 GPUs.") - -if torch.cuda.device_count() % 2 != 0: - pytest.skip("Number of device should be divided by 2.") - -if not get_torch_version() >= PkgVersion("2.4"): - pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.") - fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = torch.cuda.device_count() -LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] def _run_test(fp_init, sharding_dims): - test_path = TEST_ROOT / "run_fsdp2_model.py" - test_cmd = LAUNCH_CMD + [str(test_path)] + test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" + test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] if fp_init: test_cmd += ["--fp8-init"] @@ -50,18 +39,30 
@@ def _run_test(fp_init, sharding_dims): test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] else: assert False - result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) - if result.returncode != 0: - raise AssertionError(result.stderr.decode()) + result = subprocess.run(test_cmd, env=os.environ, check=True) -all_boolean = [True, False] -sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]] +@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") +@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") +@pytest.mark.skipif(not get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+") +@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) +@pytest.mark.parametrize("fp8_init", (False, True)) +def test_distributed(fp8_init, sharding_dims): + # Skip invalid configurations + if torch.cuda.device_count() < 4: + pytest.skip("FSDP2 test requires at least 4 GPUs") -@pytest.mark.parametrize("sharding_dims", sharding_dims) -@pytest.mark.parametrize("fp8_init", all_boolean) -def test_distributed(fp8_init, sharding_dims): if fp8_init and not fp8_available: pytest.skip(reason_for_no_fp8) + _run_test(fp8_init, sharding_dims) + + +def test_dummy() -> None: + """Dummy test + + pytest returns exit code 5 if all tests are skipped. 
+ + """ + pass diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 1fae9e99f2..4a1fd17be7 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -11,7 +11,7 @@ import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -176,6 +176,11 @@ def run_dpa_with_cp( k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) # create flash attention bias if config.attn_bias_type not in ["no_bias", "alibi"]: @@ -206,7 +211,7 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: - dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) + dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: out.backward(dout) @@ -276,7 +281,7 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: - dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) + dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index d546118ffb..ff45d1e38f 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ 
b/tests/pytorch/fused_attn/test_fused_attn.py @@ -20,6 +20,7 @@ MultiheadAttention, RotaryPositionEmbedding, get_attention_backend, + _flash_attn_is_installed, _flash_attn_2_3_plus, _flash_attn_3_is_installed, check_set_window_size, @@ -48,6 +49,12 @@ from transformer_engine.pytorch.utils import get_cudnn_version import transformer_engine_torch as tex from transformer_engine_torch import NVTE_Fused_Attn_Backend +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() @@ -257,11 +264,17 @@ def test_dot_product_attention( pad_between_seqs=pad_between_seqs, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes - if pad_between_seqs and not ( - config.max_seqlen_q != config.max_seqlen_kv - and config.attn_mask_type in ["causal", "padding_causal"] + if ( + pad_between_seqs + and _flash_attn_is_installed + and not ( + config.max_seqlen_q != config.max_seqlen_kv + and config.attn_mask_type in ["causal", "padding_causal"] + ) + and (config.window_size[0] == -1 or _flash_attn_2_3_plus) ): flash_attn_supported = True @@ -1365,13 +1378,18 @@ def _run_transformer_layer( model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 
2048, 0.0, "no_mask", "no_bias"), + "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), - "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), + "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1420,8 +1438,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] + if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( + 9, + 7, + 0, + ): + pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1447,7 +1471,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} 
==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1499,7 +1523,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: fp8_mha=fp8_mha, ) - with fp8_model_init(enabled=fp8_mha): + with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: PE = RotaryPositionEmbedding(dim=config.head_dim_qk) @@ -1523,12 +1547,26 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if not is_training: mha = mha.eval() - seqlens_q = torch.full( - [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" - ) - seqlens_kv = torch.full( - [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1565,6 +1603,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, rotary_pos_emb=rotary_pos_emb, + 
cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) if is_training: out.backward(out_grad) @@ -1594,13 +1634,29 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): config = model_configs_fp8_vs_f16[model] + # TODO(cyang): think of another way to verify dropout results + # test cuDNN FP8 dropout + # 1. we modify the config here to not affect mha_fp8_vs_f16 tests + # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an + # indirect verification method, we create Q/K/V as all 1s and check if O is all 1s + # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell + # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type: + # if get_device_compute_capability() >= (10, 0): + # config.dropout_p = 0.1 + + if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( + 9, + 7, + 0, + ): + pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1617,17 +1673,19 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): dtype, config, True, qkv_layout, is_training ) - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training - ) + if config.dropout_p == 0.0: + # test cuDNN 
FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training + ) atol = 5e-1 rtol = 5e-2 - rmse_tol = 0.1 + rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1637,27 +1695,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rtol, rmse_tol, ) - _error( - fused_attn_fwd_fp8, - fused_attn_fwd_f16, - "fused_attn_fwd_fp8", - "fused_attn_fwd_f16", - atol, - rtol, - rmse_tol, - ) - if is_training: - for i, _ in enumerate(fused_attn_bwd_f16): - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - ) + if config.dropout_p != 0.0: + # test cuDNN FP8 dropout + assert torch.all( + fused_attn_fwd_fp8 == 1 + ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." 
+ else: + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + ) def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): @@ -1696,12 +1760,26 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if not is_training: dpa = dpa.eval() - seqlens_q = torch.full( - [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" - ) - seqlens_kv = torch.full( - [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1730,7 +1808,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") tensor_shape = [dim_to_num[j] for j in layout.split("_")] - tensor = 
torch.randn(tensor_shape, dtype=dtype, device="cuda") + if config.dropout_p == 0.0: + tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") + else: + # test cuDNN FP8 dropout + tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda") tensor_count = 1 split_dim = 0 for dim, l in enumerate(layout.split("_")): @@ -1766,7 +1848,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - is_first_microbatch=True, ) if is_training: out.backward(out_grad) @@ -1819,7 +1900,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 - rmse_tol = 0.1 + rmse_tol = 0.13 _error( fused_attn_fwd_fp8, unfused_attn_fwd_f16, @@ -1973,7 +2054,9 @@ def forward( workspace: torch.Tensor, is_training: bool, mask_type: str, + quantizers: list[Quantizer], ) -> torch.Tensor: + qkv_dtype = inp.dtype assert inp.dim() == 2 in_features = qkv_weight.shape[-1] @@ -1981,83 +2064,53 @@ def forward( d = in_features // h b = cu_seqlens.numel() - 1 - fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2] + dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3] - inp_fp8, inp_t_fp8 = ext.fp8_cast_transpose_fused( - inp, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - - qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused( - qkv_weight, - 
fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, - ) + inp_fp8 = input_quantizer(inp) - M = None - ZInv = None - philox_unpacked = None + qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight) - qkv, _ = ext.fp8_gemm( + qkv, *_ = ext.general_gemm( qkv_weight_fp8, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, inp_fp8, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - torch.uint8, workspace, bias=qkv_bias, - use_bias=True, - out_index=META_QKV, - fp8_meta_tensor=fp8_meta["scaling_fwd"], + out_dtype=qkv_weight_fp8.dtype, + quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, - D_dtype=fp8_dtype_forward, ) qkv = qkv.view(-1, 3, h, d) - qkv_fp16 = ( - ext.cast_from_fp8( - qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16 - ) - .view(b, max_s, 3, h, d) - .contiguous() - ) + qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") if cudnn_frontend_version == 1: qkv = qkv.view(b, max_s, 3, h, d) # bs3hd # FMHA - out, aux_ctx_tensors, *rest = fused_attn_fwd( + q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :] + k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :] + v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :] + q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape) + k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape) + v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape) + + out, aux_ctx_tensors = fused_attn_fwd( is_training, max_s, max_s, cu_seqlens, cu_seqlens, - qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], - qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], - qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], - fp8_dtype_forward, + q, + k, + v, + 
qkv_dtype, FusedAttnBackend["FP8"], - None, - None, - None, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, @@ -2065,20 +2118,18 @@ def forward( attn_bias_type="no_bias", attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, + o_quantizer=o_quantizer, + s_quantizer=s_quantizer, ) - M, ZInv, philox_unpacked = aux_ctx_tensors - - ctx.save_for_backward( - inp_t_fp8, - qkv_weight_t_fp8, - workspace, - qkv, - out, - fp8_meta["scaling_fwd"].scale, - fp8_meta["scaling_fwd"].scale_inv, + tensors_to_save, tensor_objects = prepare_for_saving( + q, k, v, inp_fp8, qkv_weight_fp8, workspace, out ) + + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.qkv_dtype = qkv_dtype ctx.fp8_meta = fp8_meta ctx.cu_seqlens = cu_seqlens ctx.p_dropout = p_dropout @@ -2089,58 +2140,46 @@ def forward( ctx.mask_type = mask_type ctx.dtype = inp.dtype + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.S_quantizer = s_quantizer + out = out.view(-1, in_features) # (bs)(hd) - out_fp16 = ext.cast_from_fp8( - out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16 - ) + out_fp16 = out.dequantize() torch.save(out_fp16, "out.pt") # (bs)(hd) return out_fp16 @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_DPA"): - ( - inp_t_fp8, - qkv_weight_t_fp8, - workspace, - qkv, - out, 
- fwd_scales, - fwd_scale_inverses, - ) = ctx.saved_tensors - fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + saved_tensors = ctx.saved_tensors + (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved( + ctx.tensor_objects, saved_tensors + ) - proj_dgrad = ext.cast_to_fp8( - grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) # (bs)(hd) + proj_dgrad = ctx.dO_quantizer(grad_output) + fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, ctx.max_s, ctx.cu_seqlens, ctx.cu_seqlens, - qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], - qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], - qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], + q, + k, + v, out, proj_dgrad.view_as(out), - fp8_dtype_forward, + ctx.qkv_dtype, fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], None, None, - fwd_scale_inverses[META_QKV], # d_scale_qkv, - fwd_scale_inverses[META_S], # d_scale_s, - fwd_scale_inverses[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.S_quantizer, + ctx.dP_quantizer, + ctx.dQKV_quantizer, attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, @@ -2149,58 +2188,42 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) dim = 2 if 
cudnn_frontend_version == 1 else 1 - dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype) - dqkv_shape = list(dq.shape) + dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype) + dqkv_shape = list(dq._data.shape) dqkv_shape.insert(dim, 3) - dqkv_stride = list(dq.stride()) + dqkv_stride = list(dq._data.stride()) dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3)) - dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd + dqkv.set_( + dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride + ) # bs3hd dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size) - dqkv_c_fp16 = ext.cast_from_fp8( - dqkv_c, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - tex.DType.kFloat16, - ) + dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape) + dqkv_c_fp16 = dqkv_c.dequantize() torch.save(dqkv_c_fp16, "dqkv.pt") - qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused( - dqkv_c, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.dtype, - ) + qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer) + dqkv_c._transpose = None + dqkv_c._create_transpose() # QKV DGRAD - qkv_dgrad, _ = ext.fp8_gemm( - qkv_weight_t_fp8, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, + qkv_dgrad, *_ = ext.general_gemm( + qkv_weight_fp8, dqkv_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - META_DQKV, - fp8_dtype_backward, - ctx.dtype, workspace, + ctx.dtype, use_split_accumulator=_2X_ACC_DGRAD, + layout="NN", ) + # QKV WGRAD - qkv_wgrad, _ = ext.fp8_gemm( - inp_t_fp8, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - dqkv_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - META_DQKV, - fp8_dtype_backward, - ctx.dtype, + qkv_wgrad, *_ = ext.general_gemm( + inp_fp8, + dqkv, workspace, + ctx.dtype, use_split_accumulator=_2X_ACC_WGRAD, + layout="NT", ) return ( @@ -2258,7 +2281,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with 
self.prepare_forward(inp, None, num_gemms=3) as inp: + with self.prepare_forward(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, @@ -2272,5 +2295,6 @@ def forward( self.workspace, self.training, self.mask_type, + self.quantizers, ) return out diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py new file mode 100644 index 0000000000..61b4a2553c --- /dev/null +++ b/tests/pytorch/test_cpu_offloading.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +from contextlib import nullcontext + +import transformer_engine.pytorch as te + +SIZE = 4096 + +models = { + "linear": te.Linear, + "layernorm_mlp": te.LayerNormMLP, + "layernorm_linear": te.LayerNormLinear, +} + + +def _get_input(): + return torch.empty((1, SIZE, SIZE)).cuda() # input size - 1 * 2048 * 2048 * 4b = 16MB + + +def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): + torch.cuda.empty_cache() + model = model_cls(SIZE, SIZE, 1) + + input = _get_input() + if cpu_offload: + offload_context, sync_function = te.get_cpu_offload_context(enabled=True) + else: + offload_context = nullcontext() + sync_function = lambda x: x + + with te.fp8_autocast(enabled=fp8), offload_context: + out = model(input) + out = sync_function(out) + input.data = torch.Tensor() # delete data from input + out.data = torch.Tensor() # delete data from out + del input + del out + torch.cuda.empty_cache() + allocated_memory_mb = torch.cuda.memory_allocated() / 1024**2 + del model + return allocated_memory_mb + + +@pytest.mark.parametrize("fp8", [False, True]) +@pytest.mark.parametrize("model_key", models.keys()) +def test_cpu_offload(fp8, model_key) -> None: + model_cls = models[model_key] + without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) + torch.cuda.empty_cache() + with_offloading = 
_measure_memory_between_forward_and_backward(model_cls, fp8, True) + + assert without_offloading > 30 + assert with_offloading < 10 diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index d92884eaa2..dcdfa771c8 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -22,10 +22,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Record initial RNG state. @@ -49,6 +51,11 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} +fp8_recipes = [ + recipe.DelayedScaling(), + recipe.MXFP8BlockScaling(), +] + # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher @@ -152,6 +159,7 @@ def _test_cuda_graphs( fp8: bool, fp8_params: bool, fp8_weight_caching: bool, + fp8_recipe: recipe.Recipe, ) -> List[torch.Tensor]: """Helper function for CUDA graph test.""" reset_rng_states() @@ -162,7 +170,7 @@ def _test_cuda_graphs( fp8_weight_caching = False # Create modules. - with fp8_model_init(enabled=fp8_params): + with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe): if module == "transformer": modules = [ TransformerLayer( @@ -244,6 +252,7 @@ def _test_cuda_graphs( num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) elif graph_mode == "individual": # Graph individual modules. 
@@ -254,6 +263,7 @@ def _test_cuda_graphs( num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) for module in modules ] @@ -270,7 +280,7 @@ def _test_cuda_graphs( for grad_accumulation_step in range(2): input_ = generate_data(model_config, dtype) grad_output = generate_data(model_config, dtype, requires_grad=False) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 @@ -285,6 +295,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8", (False, True)) @pytest.mark.parametrize("fp8_params", (False, True)) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def test_make_graphed_callables( *, module: str, @@ -293,6 +304,7 @@ def test_make_graphed_callables( dtype: torch.dtype, fp8: bool, fp8_params: bool, + fp8_recipe: recipe.Recipe, fp8_weight_caching: bool = False, ) -> None: @@ -303,6 +315,8 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) # Run model with different CUDA graph settings. 
model_config = model_configs[model_config] @@ -314,6 +328,7 @@ def test_make_graphed_callables( fp8=fp8, fp8_params=fp8_params, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) outputs = _test_cuda_graphs(graph_mode="none", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) @@ -339,16 +354,19 @@ def test_make_graphed_callables( _test_make_graphed_callables_with_fp8_weight_caching_modules, ) @pytest.mark.parametrize("fp8_params", (False, True)) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, fp8_params: bool, + fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, dtype=torch.float32, fp8=True, fp8_params=fp8_params, + fp8_recipe=fp8_recipe, fp8_weight_caching=True, ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 96b4ab4967..56b01f1dbc 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -11,8 +11,8 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor import transformer_engine_torch as tex # PyTorch tensor dtypes @@ -42,6 +42,20 @@ def _to_list(x: Union[Iterable, Any]) -> List: fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +def to_float8( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 1.0, +) -> Float8Tensor: + """Cast tensor to FP8""" + quantizer = Float8Quantizer( + scale=torch.full([1], scale, dtype=torch.float32, device="cuda"), + amax=torch.empty([1], dtype=torch.float32, device="cuda"), + fp8_dtype=fp8_dtype, + ) + return quantizer(tensor.cuda()) + + @pytest.mark.skipif(not fp8_available, 
reason=reason_for_no_fp8) class TestFloat8Tensor: @@ -62,10 +76,11 @@ def test_constructor( """Call constructor and perform sanity checks""" dims = _to_list(dims) tensor = Float8Tensor( + shape=dims, + dtype=dtype, data=torch.zeros(dims, device="cuda", dtype=torch.uint8), fp8_dtype=fp8_dtype, fp8_scale_inv=torch.full([1], scale_inv), - dtype=dtype, ) assert list(tensor.size()) == dims, "Incorrect dims" assert tensor.dtype == dtype, "Incorrect nominal dtype" @@ -84,11 +99,7 @@ def _test_quantize_dequantize( x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 # Cast to FP8 and back - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) x_fp8 = x_fp8.dequantize().cpu() # Check results @@ -115,62 +126,6 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) - def test_fp8_meta( - self, - dtype: torch.dtype = torch.float32, - dims: DimsType = 23, - ) -> None: - """Construct Float8Tensor using FP8 metadata and perform basic checks""" - - # Get FP8 metadata from linear module - fp8_dtype = tex.DType.kFloat8E4M3 - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): - module = te.Linear(32, 32) - _ = module(torch.zeros([8, 32], device="cuda")) - fp8_meta = module.fp8_meta - fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - - # Initialize random data - dims = _to_list(dims) - x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - - # Make Float8Tensor - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_meta=fp8_meta, - fp8_meta_index=fp8_meta_index, - ) - x_ref = x_fp8.dequantize() - assert list(x_fp8.size()) == dims, "Incorrect dims" 
- assert x_fp8.dtype == dtype, "Incorrect nominal dtype" - assert x_fp8.is_cuda, "Incorrect device" - assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" - - # Change FP8 metadata scale - fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 - fp8_meta[fp8_meta_key].scale_inv.fill_(123) - - # Check results - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - with pytest.raises(AssertionError): - # Make sure we are not trivially passing the test - torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) - - # Check if scaling factor is updated after in-place ops - x_fp8 += 0 - fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 - fp8_meta[fp8_meta_key].scale_inv.fill_(321) - assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - y = x_fp8.detach() - y += 0 - assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - def test_basic_ops( self, dims: DimsType = 23, @@ -184,16 +139,8 @@ def test_basic_ops( dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - y_fp8 = Float8Tensor.to_float8( - y_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() y_ref = y_fp8.dequantize() @@ -227,16 +174,8 @@ def test_inplace_ops( dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - y_fp8 = Float8Tensor.to_float8( - y_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = 
to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() y_ref = y_fp8.dequantize() @@ -260,56 +199,6 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) - def test_transpose( - self, - dims: DimsType, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: float = 0.5, - dtype: torch.dtype = torch.float32, - ) -> None: - """Test transpose""" - - # Initialize random data - dims = _to_list(dims) - x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - x = x_fp8.dequantize() - - # Perform transpose - x_fp8_t = x_fp8.transpose_2d() - x_t = x.transpose(0, 1) - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) - - # Check results - tols = dict(rtol=0, atol=0) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - - # Make sure we are not trivially passing the test - with pytest.raises(AssertionError): - torch.testing.assert_close(x_fp8_t, x, **tols) - - # Caching test - assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." - x_fp8 += 0.5 - x = x_fp8.dequantize() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) - x_t = x.transpose(0, 1) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - - # Inplace update test - x_fp8 += 0.5 - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." 
- x = x_fp8.dequantize() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) - x_t = x.transpose(0, 1) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - def test_serialization( self, dims: DimsType = [2, 3, 5], @@ -321,11 +210,7 @@ def test_serialization( # Initialize random data dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() # Serialize tensor @@ -357,7 +242,7 @@ def test_set_data(self): # Initialize Float8Tensor x0 = torch.zeros(4, dtype=torch.float32) - x = Float8Tensor.to_float8(x0) + x = to_float8(x0) assert isinstance(x, Float8Tensor) assert x0.size() == x.size() == x._data.size() assert x.dtype == torch.float32 @@ -382,7 +267,7 @@ def test_set_data(self): assert x.device == y.device # Set data to Float8Tensor - x0 = Float8Tensor.to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) + x0 = to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) x.data = x0 assert isinstance(x, Float8Tensor) assert x0.size() == x.size() == x._data.size() diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 96acb699ad..507fd3f350 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -11,6 +11,7 @@ from torch import nn from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch.utils import is_bf16_compatible @@ -446,7 +447,7 @@ def test_bf16_model_weight_cast(self): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_model_weight_cast(self): dtype = torch.bfloat16 - 
with fp8_model_init(enabled=True): + with fp8_model_init(enabled=True, recipe=DelayedScaling()): model = MultiheadAttention( hidden_size=1024, num_attention_heads=16, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e2f712cce8..97d48e2aa3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4,7 +4,9 @@ from __future__ import annotations +from collections.abc import Iterable import math +from typing import Optional import pytest import torch @@ -12,7 +14,6 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor @@ -21,11 +22,14 @@ ForwardLinearBiasActivation, ForwardLinearBiasAdd, ) +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] @@ -36,6 +40,38 @@ _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] +def maybe_skip_quantization( + quantization: Optional[str], + *, + dims: Optional[Iterable[int] | int] = None, + device: Optional[torch.device | str] = None, +) -> None: + + # Don't skip if there is no quantization + if quantization is None: + return + + # Check if quantization scheme is supported + if quantization == "fp8" and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == 
"mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + if dims is not None: + if not isinstance(dims, Iterable): + dims = (dims,) + if quantization == "fp8": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + elif quantization == "mxfp8": + if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: + pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + + # Check if device is supported + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") + + def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: """Estimated numerical error for a datatype @@ -89,7 +125,12 @@ def make_reference_and_test_tensors( ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(test, with_transpose_cache=True) + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + test = quantizer(test) elif test.data_ptr() == ref.data_ptr(): test = test.clone() ref.copy_(test) @@ -98,6 +139,21 @@ def make_reference_and_test_tensors( return ref, test +def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name == "fp8": + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + raise ValueError(f"Unsupported quantization scheme ({name})") + + class TestSequential: """Tests for sequential container""" @@ -239,7 +295,7 @@ def test_fp8_scale_update( ) # Construct model - 
with te.fp8_model_init(): + with te.fp8_model_init(recipe=recipe): model = te_ops.basic.BasicLinear( size, size, @@ -299,35 +355,30 @@ def test_fp8_scale_update( w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin) dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin) - forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - w_scale = model.get_fp8_meta("param")[forward_key].scale - x_scale = model.get_fp8_meta("input")[forward_key].scale - dy_scale = model.get_fp8_meta("grad_output")[backward_key].scale + w_scale = model.get_quantizer("forward", 1).scale + x_scale = model.get_quantizer("forward", 0).scale + dy_scale = model.get_quantizer("backward", 0).scale torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref)) torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref)) torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref)) @pytest.mark.parametrize("init_dtype", _dtypes) @pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_dtype_cast( self, *, - size: int = 16, + size: int = 32, init_dtype: torch.dtype, final_dtype: torch.dtype, device: torch.device = "cuda", - fp8_weight: bool, + quantization: Optional[str], ) -> None: """Check dtype cast functions""" # Skip invalid configurations - if fp8_weight: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data dtype = torch.float32 @@ -339,11 +390,11 @@ def test_dtype_cast( (size, size), test_dtype=dtype, test_device=device, - 
test_is_fp8=fp8_weight, + test_is_fp8=with_quantization, ) # Construct operation - with te.fp8_model_init(enabled=fp8_weight): + with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype) with torch.no_grad(): op.weight.copy_(w_test) @@ -358,7 +409,7 @@ def test_dtype_cast( op.bfloat16() # Check weights - assert isinstance(op.weight, Float8Tensor) == fp8_weight + assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) @@ -378,29 +429,27 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_pyt_autocast( self, *, - size: int = 16, + size: int = 32, model_dtype: torch.dtype, autocast_dtype: torch.dtype, device: torch.device = "cuda", - fp8_weight: bool = False, - fp8_compute: bool, + quantization: Optional[str], + quantized_weights: bool = False, ) -> None: """Test with PyTorch autocast""" device = torch.device(device) # Skip invalid configurations - if fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization) # Construct operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weights, recipe=recipe): op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) # Check forward and backward pass @@ -410,7 +459,7 @@ def test_pyt_autocast( device=device, requires_grad=True, ) - with 
te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with torch.autocast(device_type=device.type, dtype=autocast_dtype): y = op(x) y.backward(torch.zeros_like(y)) @@ -419,11 +468,11 @@ def test_pyt_autocast( assert op.weight.grad.dtype == model_dtype # Check forward and backward pass (swapped context order) - if fp8_compute: + if quantized_compute: x.grad = None op.weight.grad = None with torch.autocast(device_type=device.type, dtype=autocast_dtype): - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y = op(x) y.backward(torch.zeros_like(y)) assert y.dtype == autocast_dtype @@ -505,19 +554,14 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize( - "memory_format", - (torch.contiguous_format, torch.channels_last), - ) @pytest.mark.parametrize("fp8", (False, True)) def test_reshape( self, *, shapes: tuple[Iterable[int], Iterable[int]], dtype: torch.dtype, - device: torch.device, - memory_format: torch.memory_format, + device: torch.device = "cuda", + memory_format: torch.memory_format = torch.contiguous_format, fp8: bool, ) -> None: in_shape, out_shape = shapes @@ -634,19 +678,23 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) - def test_cast_float8( + def test_quantize( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda", + quantization: str, cast_forward: bool, cast_backward: bool, ) -> None: - """FP8 cast""" + """Quantize""" + + # Skip 
invalid configurations + maybe_skip_quantization(quantization) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -656,7 +704,7 @@ def test_cast_float8( requires_grad=False, test_is_fp8=True, ) - x_test = x_test.from_float8().requires_grad_() + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, @@ -664,7 +712,7 @@ def test_cast_float8( requires_grad=False, test_is_fp8=True, ) - dy_test = dy_test.from_float8() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -672,16 +720,14 @@ def test_cast_float8( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) + recipe = make_recipe(quantization) with te.fp8_autocast(fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert is_float8_tensor(y_test) == cast_forward - assert is_float8_tensor(x_test.grad) == cast_backward + assert isinstance(y_test, QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -697,12 +743,13 @@ def _test_basic_linear( in_shape: Iterable[int] = (32, -1), dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, - fp8_grad_output: bool = False, - fp8_grad_input: bool = False, + quantization: Optional[str] = None, + quantized_compute: bool = False, + quantized_input: bool = False, + quantized_weight: bool = False, + quantized_output: bool = False, + quantized_grad_output: bool = False, + quantized_grad_input: bool = False, accumulate_into_main_grad: bool = False, ) -> None: """Helper function for tests with GEMM""" @@ -713,50 +760,50 @@ def _test_basic_linear( 
out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_compute or fp8_input or fp8_weight or fp8_output or fp8_grad_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization == "fp8" and quantized_output and not quantized_compute: pytest.skip("FP8 output is only supported with FP8 GEMMs") - if fp8_grad_input and not fp8_compute: + if quantization == "fp8" and quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + if quantization == "mxfp8" and quantized_output: + pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") + if quantization == "mxfp8" and quantized_grad_input: + pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=(quantized_compute or quantized_input), ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=(quantized_compute or quantized_grad_output), requires_grad=False, ) + if 
isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref.backward(dy_ref) # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -769,14 +816,11 @@ def _test_basic_linear( del w_test op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32) forward = te_ops.Sequential( - te_ops.Quantize(forward=fp8_input, backward=fp8_grad_input), + te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input), op, - te_ops.Quantize(forward=fp8_output, backward=fp8_grad_output), - ) - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, + te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output), ) - with te.fp8_autocast(enabled=fp8_compute, fp8_recipe=recipe): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -784,10 +828,8 @@ def _test_basic_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute or fp8_output or fp8_grad_input: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute or quantized_output or quantized_grad_input: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -813,10 +855,10 @@ def _test_basic_linear( ) torch.testing.assert_close(dw_test, w_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) + @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) + @pytest.mark.parametrize("in_shape", 
((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -824,7 +866,7 @@ def test_basic_linear( weight_shape: tuple[int, int], in_shape: Iterable[int], dtype: torch.dtype, - fp8_compute: bool, + quantization: Optional[str], accumulate_into_main_grad: bool, ) -> None: """GEMM""" @@ -832,52 +874,55 @@ def test_basic_linear( weight_shape=weight_shape, in_shape=in_shape, dtype=dtype, - fp8_compute=fp8_compute, + quantization=quantization, + quantized_compute=quantization is not None, accumulate_into_main_grad=accumulate_into_main_grad, ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_input", (False, True)) - def test_basic_linear_fp8( + @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_input", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("quantized_output", (False, True)) + @pytest.mark.parametrize("quantized_grad_output", (False, True)) + @pytest.mark.parametrize("quantized_grad_input", (False, True)) + def test_basic_linear_quantized( self, *, - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, - fp8_output: bool, - fp8_grad_output: bool, - fp8_grad_input: bool, + quantization: str, + quantized_compute: bool, + quantized_input: bool, + quantized_weight: bool, + quantized_output: bool, + quantized_grad_output: bool, + 
quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" self._test_basic_linear( dtype=torch.bfloat16, - fp8_compute=fp8_compute, - fp8_input=fp8_input, - fp8_weight=fp8_weight, - fp8_output=fp8_output, - fp8_grad_output=fp8_grad_output, - fp8_grad_input=fp8_grad_input, + quantization=quantization, + quantized_compute=quantized_compute, + quantized_input=quantized_input, + quantized_weight=quantized_weight, + quantized_output=quantized_output, + quantized_grad_output=quantized_grad_output, + quantized_grad_input=quantized_grad_input, ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, *, bias: bool, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool, + quantization: Optional[str], + quantized_weight: bool, ) -> None: """GEMM + bias""" @@ -887,31 +932,25 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) # Random data x_ref, x_test = 
make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -932,7 +971,8 @@ def test_linear( y_ref.backward(dy_ref) # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.Linear( in_features, out_features, @@ -946,7 +986,7 @@ def test_linear( op.bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -954,10 +994,8 @@ def test_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -970,12 +1008,11 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) - @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("weight_shape", ((7, 2), (32,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - 
@pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_layer_norm( self, *, @@ -985,8 +1022,7 @@ def test_layer_norm( device: torch.device = "cuda", eps: float = 0.3, zero_centered_gamma: bool, - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Layer norm""" @@ -994,18 +1030,13 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, ) w_ref, w_test = make_reference_and_test_tensors( weight_shape, @@ -1047,17 +1078,19 @@ def test_layer_norm( op.bias.copy_(b_test) del w_test del b_test + quantized_compute = quantization is not None + recipe = make_recipe(quantization) forward = te_ops.Sequential( op, - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1145,12 +1178,11 @@ def test_layer_norm_autocast( torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype)) torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype)) - @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) - @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("weight_shape", 
((19,), (64,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_rmsnorm( self, *, @@ -1160,8 +1192,7 @@ def test_rmsnorm( device: torch.device = "cuda", eps: float = 0.3, zero_centered_gamma: bool, - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Layer norm""" @@ -1169,18 +1200,13 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, ) w_ref, w_test = make_reference_and_test_tensors( weight_shape, @@ -1214,17 +1240,19 @@ def test_rmsnorm( with torch.no_grad(): op.weight.copy_(w_test) del w_test + quantized_compute = quantization is not None + recipe = make_recipe(quantization) forward = te_ops.Sequential( op, - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1363,10 +1391,9 @@ def test_make_extra_output( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("activation", ("relu", "gelu", 
"geglu", "reglu", "swiglu")) - @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_activation( self, *, @@ -1374,8 +1401,7 @@ def test_activation( out_shape: Iterable[int], dtype: torch.dtype, device: torch.device = "cuda", - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Activation functions""" @@ -1385,19 +1411,19 @@ def test_activation( in_shape[-1] *= 2 # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, + test_is_fp8=quantized_compute, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, @@ -1425,6 +1451,7 @@ def test_activation( y_ref.backward(dy_ref) # Implementation with fusible operation + recipe = make_recipe(quantization) make_op = dict( gelu=te_ops.GELU, relu=te_ops.ReLU, @@ -1434,16 +1461,18 @@ def test_activation( )[activation] forward = te_ops.Sequential( make_op(), - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols 
= dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) + if activation == "relu": + tols = {"atol": 0, "rtol": 0} # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1452,16 +1481,18 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_input", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( self, *, - out_shape: Iterable[int] = (16, 16), + out_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device = "cuda", - fp8_output: bool, - fp8_grad_input: bool, + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, ): # Tensor dimensions @@ -1469,19 +1500,10 @@ def test_swiglu( in_shape[-1] *= 2 # Skip invalid configurations - fp8 = fp8_output or fp8_grad_input - if fp8: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - - # FP8 recipe - fp8_recipe = None - if fp8_grad_input: - fp8_recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1502,18 +1524,19 @@ def test_swiglu( y_ref.backward(dy_ref) # Implementation with fusible operation + recipe = make_recipe(quantization) forward = te_ops.Sequential( - te_ops.Quantize(forward=False, 
backward=fp8_grad_input), + te_ops.Quantize(forward=False, backward=quantize_backward), te_ops.SwiGLU(), - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantize_forward, backward=False), ) - with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1533,12 +1556,11 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("weight_shape", ((32, 48), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (4, 2, 10, -1))) + @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, *, @@ -1547,9 +1569,8 @@ def test_forward_linear_bias_activation( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, + quantization: Optional[str], + quantized_weight: bool, ) -> None: """Forward GEMM + bias + activation""" @@ -1559,18 +1580,9 @@ def test_forward_linear_bias_activation( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - 
math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output" @@ -1581,13 +1593,16 @@ def test_forward_linear_bias_activation( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1608,7 +1623,8 @@ def test_forward_linear_bias_activation( y_ref.backward(dy_ref) # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1624,7 +1640,7 @@ def test_forward_linear_bias_activation( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -1637,12 +1653,8 @@ def test_forward_linear_bias_activation( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check 
results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1657,19 +1669,17 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_forward_linear_bias_add( self, *, bias: bool, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, + quantization: Optional[str], + quantized_weight: bool = False, ) -> None: """Forward GEMM + bias + add""" @@ -1679,21 +1689,10 @@ def test_forward_linear_bias_add( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_output or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: - pytest.skip("FP8 output requires FP8 compute") - if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") # Random data @@ -1701,13 +1700,16 @@ def test_forward_linear_bias_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + 
test_is_fp8=quantized_compute, ) + if isinstance(x1_test, QuantizedTensor): + with torch.no_grad(): + x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1720,7 +1722,6 @@ def test_forward_linear_bias_add( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_output, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1734,7 +1735,8 @@ def test_forward_linear_bias_add( y_ref.backward(dy_ref) # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1751,7 +1753,7 @@ def test_forward_linear_bias_add( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -1764,12 +1766,8 @@ def test_forward_linear_bias_add( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1785,18 +1783,16 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_backward_linear_add( self, *, - weight_shape: tuple[int, int] = (16, 16), - 
in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, + quantization: Optional[str], + quantized_weight: bool = False, ) -> None: """Backward dgrad GEMM + add""" @@ -1806,21 +1802,10 @@ def test_backward_linear_add( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_output or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: - pytest.skip("FP8 output requires FP8 compute") - if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") # Random data @@ -1828,13 +1813,16 @@ def test_backward_linear_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -1855,7 +1843,8 @@ def 
test_backward_linear_add( (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.MakeExtraOutput(), te_ops.Linear( @@ -1869,7 +1858,7 @@ def test_backward_linear_add( with torch.no_grad(): model[1].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y1_test, y2_test = model(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -1882,12 +1871,8 @@ def test_backward_linear_add( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[1].weight._fp8_dtype - if is_float8_tensor(model[1].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e9b6303933..2401f3ca95 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -13,7 +13,11 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + fp8_autocast, + fp8_model_init, +) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -35,13 +39,16 @@ Fp8Unpadding, ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.tensor.float8_tensor 
import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.common import recipe import transformer_engine_torch as tex -# Only run FP8 tests on H100. +# Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -90,6 +97,11 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq mask_types = ["causal", "no_mask"] +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), +] + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -450,7 +462,8 @@ def __init__( self.fc2 = nn.Linear(ffn_hidden_size, hidden_size) def forward(self, x): - return self.fc2(self.gelu(self.fc1(self.ln(x)))) + t = self.gelu(self.fc1(self.ln(x))) + return self.fc2(t) class TorchGPT(nn.Module): @@ -480,7 +493,9 @@ def forward( return x -def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): +def _test_e2e_selective_recompute( + bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False +): reset_rng_states() FP8GlobalStateManager.reset() @@ -488,7 +503,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -515,7 +530,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False 
te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -536,18 +551,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) -def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] outputs = _test_e2e_selective_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=False + bs, dtype, config, fp8, recipe, fp8_model_params, recompute=False ) outputs_recompute = _test_e2e_selective_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=True + bs, dtype, config, fp8, recipe, fp8_model_params, recompute=True ) # Check that results match @@ -556,6 +574,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par tols["atol"] = 1e-4 if fp8 or fp8_model_params: tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): torch.testing.assert_close( test, @@ -566,7 +585,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par def _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True + bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True ): reset_rng_states() FP8GlobalStateManager.reset() @@ -575,7 +594,7 @@ 
def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -603,7 +622,7 @@ def _test_e2e_full_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: te_out = te_checkpoint( block, @@ -641,11 +660,16 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) -def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant): +def test_gpt_full_activation_recompute( + dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant +): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] @@ -654,10 +678,24 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0" outputs, names = _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant + bs, + dtype, + config, + fp8, + recipe, + fp8_model_params, + recompute=False, + use_reentrant=use_reentrant, ) outputs_recompute, _ = _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant + bs, + dtype, + config, + fp8, + recipe, + fp8_model_params, + recompute=True, + use_reentrant=use_reentrant, ) if not 
use_reentrant: @@ -741,7 +779,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= del block block = _test_e2e_checkpointing_get_model(config, dtype) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) reset_rng_states() for p in block.parameters(): @@ -1267,9 +1305,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere torch.half: 2e-3, torch.bfloat16: 2e-2, } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } # Check output. - assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype]) if model == "small": atol = { @@ -1335,8 +1378,14 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): torch.bfloat16: 5e-2, } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } + # Check output. - assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype]) # Check gradients, only for small model rtol = { @@ -1351,7 +1400,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): +def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): reset_rng_states() if fp8: FP8GlobalStateManager.reset() @@ -1365,16 +1414,22 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False inp_hidden_states.retain_grad() if num_gemms > 1: - m = config.seq_len // 16 + split_size = 1 + if fp8: + if recipe.delayed(): + split_size = 16 + if recipe.mxfp8(): + split_size = 128 + m = config.seq_len // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add 
a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 + m_splits = m_splits * split_size assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms else: m_splits = torch.tensor([config.seq_len]) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): m_splits = m_splits * bs out = block(inp_hidden_states, m_splits.tolist()) @@ -1401,18 +1456,23 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None + dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches + pytest.skip("MXFP8 unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1442,9 +1502,11 @@ def test_grouped_linear_accuracy( sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) - outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, config, fp8) outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, fp8 + 
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) # Shoule be bit-wise match @@ -1453,7 +1515,8 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) -def test_grouped_linear_accuracy_parallel_mode(parallel_mode): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, @@ -1461,12 +1524,14 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): bs=2, model="126m", fp8=True, + recipe=recipe, fp8_model_params=True, parallel_mode=parallel_mode, ) -def test_grouped_linear_accuracy_single_gemm(): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, @@ -1474,11 +1539,12 @@ def test_grouped_linear_accuracy_single_gemm(): bs=2, model="126m", fp8=True, + recipe=recipe, fp8_model_params=True, ) -def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): """Padding tensor shapes to multiples of 16.""" @@ -1546,7 +1612,7 @@ def _generate_random_numbers(n, total_sum): m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): out = block(inp_hidden_states, m_splits) else: @@ -1575,18 +1641,23 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) 
+@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None + dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches + pytest.skip("MXFP8 unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -1597,7 +1668,7 @@ def test_padding_grouped_linear_accuracy( fp8=fp8, ).eval() - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1619,10 +1690,10 @@ def test_padding_grouped_linear_accuracy( ) outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, fp8 + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, fp8 + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) # Shoule be bit-wise match @@ -1734,7 +1805,7 @@ def test_gpt_cuda_graph(dtype, bs, model): assert_allclose(grads, graphed_grads, 1e-3) -def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): +def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): reset_rng_states() FP8GlobalStateManager.reset() @@ -1742,7 +1813,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, 
fp8_model_params): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_model_params): + with fp8_model_init(enabled=fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -1769,7 +1840,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=True): + with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -1785,14 +1856,17 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -def test_gpt_fp8_parameters(dtype, bs, model): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_gpt_fp8_parameters(dtype, bs, model, recipe): if not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] - outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) - outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) + outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) + outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe) # Check that results match tols = dict(rtol=0.125, atol=0.0675) @@ -2073,23 +2147,24 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): out_ref = [o.clone() for o in out] for i in range(z): - gemm( + general_gemm( A[i], B[i], - dtype, get_workspace(), + dtype, grad=grad, accumulate=accumulate, layout=layout, out=out_ref[i], ) - grouped_gemm( + general_grouped_gemm( A, - B, - out, + list(B), + list(out), dtype, get_multi_stream_cublas_workspace(), + 
m_splits=[k] * n, # TODO, not sure grad=grad, accumulate=accumulate, layout=layout, @@ -2124,64 +2199,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): out_ref = [o.clone() for o in out] # fp8 should be robust enough to this fake scale - scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda") - scale_inv = 1 / scale - amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda") + scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() + amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - A_fp8 = [ - torch.ops.tex_ts.cast_to_fp8_ts( - A[i], - scale, - amax, - scale_inv, - i, # fp8 meta tensor index + a_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), tex.DType.kFloat8E4M3, ) - for i in range(z) + for _ in range(z) ] - B_fp8 = [ - torch.ops.tex_ts.cast_to_fp8_ts( - B[i], - scale, - amax, - scale_inv, - z + i, # fp8 meta tensor index - fp8_dtype, + b_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, ) - for i in range(z) + for _ in range(z) ] - fp8_grouped_gemm( - A_fp8, - [scale_inv], - 0, # A_offset - tex.DType.kFloat8E4M3, - B_fp8, - scale_inv, - z, # B_offset - fp8_dtype, - out, - dtype, - get_multi_stream_cublas_workspace(), - accumulate=accumulate, - ) + A_fp8 = [] + B_fp8 = [] + + for i in range(z): + A_fp8.append(a_quantizers[i](A[i])) + B_fp8.append(b_quantizers[i](B[i])) # baseline for i in range(z): - fp8_gemm( + general_gemm( A_fp8[i], - scale_inv, - i, - tex.DType.kFloat8E4M3, B_fp8[i], - scale_inv, - z + i, - fp8_dtype, - dtype, get_workspace(), + dtype, out=out_ref[i], accumulate=accumulate, ) + general_grouped_gemm( + A_fp8, + B_fp8, + out, + dtype, + get_multi_stream_cublas_workspace(), + m_splits=[k] * m_splits, + accumulate=accumulate, + ) # should be bit-wise match for o, o_ref in zip(out, out_ref): diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py deleted file mode 100644 index 46e888462a..0000000000 
--- a/tests/pytorch/test_onnx_export.py +++ /dev/null @@ -1,1562 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -This file contains tests for exporting TransformerEngine models to ONNX. - -The purpose of these tests is validation that TE models are converted to their correct ONNX -representation. Toward this end, each test captures the output of a TE module forward pass, -converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and -validate the output against TE's output. - -Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented -using custom ORT operations. - -To run many repetitive tests use pytest-loop: - $ python3 -m pip install pytest-loop - $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm - -For reproducability use: torch.manual_seed(0) -""" - -import os -import tempfile -import pytest -import warnings -import numpy as np -import onnxruntime as ort -import torch -from torch import nn as nn -from typing import Optional, Union, Tuple, List -import transformer_engine.pytorch as te -from transformer_engine.common import recipe -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import ( - gemm, - fp8_gemm, - gelu, - cast_to_fp8, - cast_from_fp8, -) -from transformer_engine.pytorch.module.base import get_workspace -import transformer_engine.pytorch.cpp_extensions as texcpp -import transformer_engine.pytorch.softmax as softmax_defs -from transformer_engine.pytorch.utils import get_default_init_method -from transformer_engine.pytorch.export import is_in_onnx_export_mode -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager - -# Global test configuration knobs. - -# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance). 
-SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0"))) - -if SAVE_TEST_IO: - from polygraphy.json import save_json - from polygraphy.comparator import RunResults - -# The directory where generated ONNX test models are stored. -NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR") -NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join( - tempfile.gettempdir(), "./gen_onnx_models" -) - - -# The directory where this file is stored. -TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) - -# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14. -TRILU_OPSET = 14 -# Opset used in the ONNX files generated by the tests. -OPSET = 17 -assert OPSET >= TRILU_OPSET - -# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT). -ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so") - -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - -supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] - -all_normalizations = ["LayerNorm", "RMSNorm"] - - -@pytest.fixture() -def seed_default_rng(): - """Reseed the PRNG for test reproducibility""" - torch.manual_seed(1234) - - -@pytest.fixture() -def set_max_seq_len(max_seq_len=128): - """Set the maximum sequence length that can be used for attention masking""" - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}" - - -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - FP8GlobalStateManager.reset() - - -def create_fp8_recipe(): - return recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) - - -def do_export( - model: torch.nn.Module, - inp: torch.Tensor, - fname: str, - use_fp8: bool = True, - opset: int = OPSET, - input_names: List[str] = None, - output_names: List[str] = None, - dynamic_axes: List[str] = None, -): - """Export to ONNX""" - fp8_recipe 
= create_fp8_recipe() - input_names = input_names or ["input"] - output_names = output_names or ["output"] - - with torch.inference_mode(), te.fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe - ), warnings.catch_warnings(): - warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") - - model.cuda().eval() - os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True) - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - - inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,) - assert len(inps) == len(input_names) - inds_to_del = [i for i in range(len(inps)) if inps[i] is None] - input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del] - - with te.onnx_export(True): - torch.onnx.export( - model, - inps, - fname, - verbose=True, - dynamic_axes=dynamic_axes, - opset_version=opset, - input_names=input_names, - output_names=output_names, - do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, - ) - - -def to_numpy(tensor): - if isinstance(tensor, torch.Tensor): - if tensor.dtype == torch.bfloat16: - tensor = tensor.type(torch.float32) - tensor = tensor.detach().cpu().numpy() - return tensor - - -def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): - """Initialize the FP8 quantization scales in module""" - NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. 
- nb_total_scales = num_gemms * NB_SCALES_PER_GEMM - module.init_fp8_metadata(num_gemms) - module.fp8_meta["scaling_fwd"].scale = ( - torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale - ) - module.fp8_meta["scaling_fwd"].scale_inv = ( - torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale - ) - - -def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): - """Transformer Engine forward propagation.""" - fp8_recipe = create_fp8_recipe() - with torch.inference_mode(), te.fp8_autocast( - enabled=is_fp8, fp8_recipe=fp8_recipe - ), warnings.catch_warnings(): - te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) - if not isinstance(te_outputs, tuple): - te_outputs = (te_outputs,) - return te_outputs - - -def compare_outputs( - onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname -): - """Compare ORT and TE outputs.""" - assert len(onnx_outputs) == len(te_outputs) - # Compare ORT and PyTorch outputs. - for onnx_output, te_output in zip(onnx_outputs, te_outputs): - # np.isclose: abs(a - b) <= (atol + rtol * abs(b)) - te_output = to_numpy(te_output) - onnx_output = to_numpy(onnx_output) - ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol) - mismatches = ac.nonzero() - mismatched_ids = [loc for loc in zip(*mismatches)] - if mismatched_ids: - # Log some information in case of error. 
- print("*" * 100) - nb_errors = len(mismatched_ids) - nb_vals = min(nb_errors, max_errors_printed) - print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})") - print(f"Showing first {nb_vals} errors (ONNX -- TE):") - abs_err = np.abs(onnx_output - te_output) - errors = abs_err[mismatches] - for loc in mismatched_ids[:nb_vals]: - ref = te_output[loc] - print( - f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >" - f" {atol + rtol * abs(ref)}" - ) - print(f"Max error: {np.max(errors)}") - if nb_errors > allow_cnt_errors: - raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors") - - -def serialize_inputs_outputs( - fname: str, - inputs: Union[Tuple[torch.Tensor], torch.Tensor], - te_outputs: List[torch.Tensor], - input_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, -): - if not SAVE_TEST_IO: - return - - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - - input_names = input_names or ["input"] - output_names = output_names or ["output"] - inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) - named_inputs = zip(input_names, inputs) - input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}] - json_fname = fname[: -len(".onnx")] + "_inputs.json" - save_json(input_data, json_fname, description="custom input data") - - json_fname = fname[: -len(".onnx")] + "_output.json" - named_outputs = zip(output_names, te_outputs) - output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None} - custom_outputs = RunResults() - custom_outputs.add([output_data], runner_name="custom_runner") - custom_outputs.save(json_fname) - - -def validate_result( - fname: str, - inps: Union[Tuple[torch.Tensor], torch.Tensor], - model: torch.nn.Module, - atol: float = 1.0e-8, # np.isclose default atol - rtol: float = 1.0e-5, # np.isclose default rtol - max_errors_printed: int = 10, - is_fp8: bool = False, - allow_cnt_errors: int = 0, - 
input_names: List[str] = None, - output_names: List[str] = None, - te_outputs: List[torch.Tensor] = None, -): - """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX - representation using ONNX Runtime (ORT) and ensure they are close. - - The purpose of the output comparison is to validate that TE models are converted to - their correct ONNX representation by testing that TE and ORT outputs match within some - small threshold (allowing for finite precision errors). - - Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring, - a very small number (0-3) of outliers. This is fine to do because these outliers are due to - small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX - representation (the tests assume both ORT or TE kernels are correct). - - Argument `te_outputs` can be used to provide pre-computed TE outputs. - """ - - def create_ort_session(fname: str, is_fp8: bool): - def load_custom_ops(session_opts: ort.SessionOptions): - """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension.""" - if not os.path.exists(ORT_CUSTOM_OPS_LIB): - raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}") - session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB) - print("registered custom FP8 Q/DQ ops!") - - """Create an ONNX Runtime session for validation.""" - kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]} - if is_fp8: - sess_options = ort.SessionOptions() - load_custom_ops(sess_options) - kwargs["sess_options"] = sess_options - - s = ort.InferenceSession(fname, **kwargs) - return s - - def create_ort_input_dict(session, inputs): - inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) - input_names = [x.name for x in session.get_inputs()] - inps = [to_numpy(x) for x in inputs if x is not None] - inp_dict = dict(zip(input_names, inps)) - return inp_dict - - input_names = 
input_names or ["input"] - output_names = output_names or ["output"] - - # Run ORT session and TE model. - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - if not te_outputs: - te_outputs = te_infer(model, inps, is_fp8) - ort_s = create_ort_session(fname, is_fp8) - input_feed = create_ort_input_dict(ort_s, inps) - onnx_outputs = ort_s.run(None, input_feed=input_feed) - compare_outputs( - onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname - ) - - -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - -def dtype2str(dtype: torch.dtype, fake_bf16_io=False): - if fake_bf16_io: - assert dtype == torch.bfloat16 - return "_fake_bf16" - return { - torch.float32: "_fp32", - torch.float16: "_fp16", - torch.bfloat16: "_bf16", - }[dtype] - - -def as_te_type(dtype: torch.dtype): - return { - torch.float32: tex.DType.kFloat32, - torch.float16: tex.DType.kFloat16, - torch.bfloat16: tex.DType.kBFloat16, - }[dtype] - - -def get_attn_mask_str(use_mask, attn_mask_type): - # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names. 
- if attn_mask_type is None: - return "_mask" if use_mask else "_no-mask" - attn_mask_str = "_arbitrary-no-mask" - attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str - attn_mask_str = ( - "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str - ) - return attn_mask_str - - -class FP8GemmModule(nn.Module): - def __init__(self, precision, use_bias, gelu, scale_factors, hidden_size, out_features): - super().__init__() - self.use_bias = use_bias - self.gelu = gelu - self.precision = precision - - self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT - nb_inp_scales, nb_weight_scales = 1, out_features - act_scale_factor, weight_scale_factor = scale_factors - self.meta_inp = create_meta(act_scale_factor, nb_inp_scales) - self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales) - - bias_size = nb_weight_scales - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") - - self.inp_type = tex.DType.kFloat8E4M3 - self.weights_type = tex.DType.kFloat8E4M3 - self.outp_type = precision - - def forward(self, inp, weight): - inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) - - weight_fp8 = cast_to_fp8( - weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type - ) - - ret, _ = fp8_gemm( - weight_fp8, - self.meta_weight.scale_inv, - self.fp8_tensor_weight, - self.inp_type, - inp_fp8, - self.meta_inp.scale_inv, - self.fp8_tensor_inp, - self.weights_type, - self.outp_type, - get_workspace(), - bias=self.bias, - use_bias=self.use_bias, - use_split_accumulator=False, - ) - return ret - - -""" -Tests cases begin here. 
-""" - - -@skip_FP8 -@pytest.mark.parametrize("scale_factor", [1, 224]) -@pytest.mark.parametrize( - "precision, atol", - [ - [torch.float32, 1e-7], - [torch.float16, 1e-7], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], - ], -) -def test_export_cast_ops( - seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - class TestFP8_QDQ(nn.Module): - def __init__(self, fake_bf16_io): - super().__init__() - self.fp8_tensor = 0 - self.meta = create_meta(scale_factor) - self.highprec_type = as_te_type(precision) - self.fp8_type = tex.DType.kFloat8E4M3 - self.fake_bf16_io = fake_bf16_io - - def forward(self, inp): - ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type) - - ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - # Set dimensions (these are arbitrary). 
- in_features = 64 - hidden_size = 256 - inp = torch.randn( - hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision - ) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx" - model = TestFP8_QDQ(fake_bf16_io) - - do_export(model, inp, fname) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs) - - -@skip_FP8 -@pytest.mark.parametrize("scale_factor", [448]) -@pytest.mark.parametrize( - "precision, atol", - [ - [torch.float32, 1e-5], - [torch.float16, 1e-5], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], - ], -) -def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - class TestFP8_Gelu(nn.Module): - def __init__(self, fake_bf16_io): - super().__init__() - self.fp8_tensor = 0 - self.meta = create_meta(scale_factor) - self.highprec_type = as_te_type(precision) - self.fp8_type = tex.DType.kFloat8E4M3 - self.fake_bf16_io = fake_bf16_io - - def forward(self, inp): - ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type) - ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - # Set dimensions (these are arbitrary). 
- in_features = 64 - hidden_size = 256 - inp = torch.randn( - hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision - ) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx" - model = TestFP8_Gelu(fake_bf16_io) - do_export(model, inp, fname) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, - inp, - model, - rtol=0, - atol=atol, - is_fp8=True, - allow_cnt_errors=2, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize( - "scale_factors", - [ - ( - 224, - 224, - ), - ], -) -@pytest.mark.parametrize( - "precision, use_fp8, use_bias, use_gelu", - [ - (torch.float32, False, False, False), - (torch.float16, False, False, False), - (torch.bfloat16, False, False, False), - (torch.float32, False, True, False), - (torch.float16, False, True, False), - (torch.bfloat16, False, True, False), - (torch.float32, False, True, True), - (torch.float16, False, True, True), - (torch.bfloat16, False, True, True), - # For FP8 GEMM GeLU is not used. 
- (torch.float32, True, False, False), - (torch.float16, True, False, False), - (torch.bfloat16, True, False, False), - # When enabling bias we must use float16 or bfloat16 (because of kernel limitations) - (torch.float16, True, True, False), - (torch.bfloat16, True, True, False), - ], -) -def test_export_gemm( - seed_default_rng, - precision, # Precision of inputs, weights, output and bias - use_fp8, - use_bias, - use_gelu, - scale_factors, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - class Test_GEMM(nn.Module): - def __init__(self, precision, use_bias=False, gelu=False): - super().__init__() - self.use_bias = use_bias - self.gelu = gelu - self.precision = precision - bias_size = out_features - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") - - def forward(self, inp, weight): - outp_type = self.precision - - # note: due to logic in lines 104:116 and L129 in cpp_extensions.py - # it appears either bias OR gelu can be activated, not both - ret, _, _ = gemm( - weight, - inp, - outp_type, - get_workspace(), - # test bias - bias=self.bias, - use_bias=self.use_bias, - # test gelu - gelu=self.gelu, - gelu_input=self.gelu_input, - grad=False, # only True for backward pass - accumulate=False, - ) - return ret - - # If gelu is applied then bias must be added, as defined by TE kernel. - if use_gelu: - assert use_bias - # Set dimensions (these are arbitrary). 
- out_features = 128 - hidden_size = 256 - in_features = 64 - inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) - weight = torch.randn(out_features, in_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - gelu_str = "_gelu" if use_gelu else "" - high_prec_str = dtype2str(precision) - fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx" - input_names = ["input", "weight"] - if use_fp8: - model = FP8GemmModule( - precision, use_bias, use_gelu, scale_factors, hidden_size, out_features - ) - do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision != torch.bfloat16: - validate_result( - fname, - (inp, weight), - model, - rtol=1e-2, - atol=2e-2, - is_fp8=True, - input_names=input_names, - te_outputs=te_outputs, - ) - else: - model = Test_GEMM(precision, use_bias, use_gelu) - do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision != torch.bfloat16: - validate_result( - fname, - (inp, weight), - model, - rtol=1e-2, - atol=2e-2, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize("scale_factor", [448, 112]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize( - "use_fp8, precision, atol", - [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2], - ], -) -def test_export_layernorm( - seed_default_rng, - use_fp8: bool, - 
scale_factor: float, - precision: torch.dtype, - zero_centered_gamma: bool, - atol: float, -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - inp_shape = [64, 32] - - class Test_Layernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - eps = 1e-6 # An arbitrary small value - dtype = torch.float if fake_bf16_io else precision - self.ln = ( - te.LayerNorm( - inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma - ) - .eval() - .cuda() - ) - - def forward(self, inp): - ret = self.ln(inp) - return ret - - class TestFP8_Layernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.bias = torch.zeros( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - - def forward(self, inp): - ret = texcpp.layernorm_fwd_fp8_inf( - inp, - self.weight, - self.bias, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - ret = cast_from_fp8( - ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) - ) - if fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm() - high_prec_str = dtype2str(precision, 
fake_bf16_io=fake_bf16_io) - fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" - fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" - do_export(model, inp, fname, use_fp8=use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("scale_factor", [448, 112]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize( - "use_fp8, precision, atol", - [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2], - ], -) -def test_export_rmsnorm( - seed_default_rng, - use_fp8: bool, - scale_factor: float, - precision: torch.dtype, - zero_centered_gamma: bool, - atol: float, -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). 
- inp_shape = [64, 32] - - class Test_RMSnorm(nn.Module): - def __init__(self) -> None: - super().__init__() - eps = 1e-6 # An arbitrary small value - dtype = torch.float if fake_bf16_io else precision - self.ln = ( - te.RMSNorm( - inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma - ) - .eval() - .cuda() - ) - - def forward(self, inp): - ret = self.ln(inp) - return ret - - class TestFP8_RMSnorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - - def forward(self, inp): - ret = texcpp.rmsnorm_fwd_fp8_inf( - inp, - self.weight, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - ret = cast_from_fp8( - ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) - ) - if fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - model = TestFP8_RMSnorm() if use_fp8 else Test_RMSnorm() - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" - fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" - do_export(model, inp, fname, use_fp8=use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("scale_factor", [1]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care 
about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, False), - (torch.float16, True), - # Todo: cannot configure BF16 when bias is disabled (ORT issue?) - (torch.bfloat16, False), - # Todo: cannot configure BF16 when bias is enabled (ORT issue?) - (torch.bfloat16, True), - ], -) -def test_export_linear( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - precision: torch.dtype, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - in_features = 64 - out_features = 256 - hidden_size = 256 - - class Test_Linear(nn.Module): - def __init__(self, in_features, out_features, use_bias, return_bias, precision): - super().__init__() - self.linear = te.Linear( - in_features, - out_features, - bias=use_bias, - return_bias=return_bias, - params_dtype=precision, - ) - - def forward(self, inp): - ret = self.linear(inp) - return ret - - inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" - with te.fp8_autocast(enabled=use_fp8): - model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( - device="cuda" - ) - if use_fp8: - set_layer_scale(model.linear, scale_factor, num_gemms=1) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - - if precision in (torch.bfloat16,): - return - if not use_fp8: - validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) - else: - validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8, te_outputs=te_outputs) - - 
-@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_linear( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). 
- in_features = 64 - out_features = 256 - hidden_size = 256 - - inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" - - with te.fp8_autocast(enabled=use_fp8): - model = te.LayerNormLinear( - hidden_size, - 3 * hidden_size, - bias=use_bias, - return_bias=return_bias, - return_layernorm_output=return_layernorm_output, - params_dtype=precision, - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ).to(device="cuda") - if use_fp8: - set_layer_scale(model, scale_factor, num_gemms=1) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16,): - return - if not use_fp8: - validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) - elif precision != torch.bfloat16: - validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs) - - -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. 
-@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_mlp( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - activation: str, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - in_features = 64 - out_features = 256 - hidden_size = 256 - ffn_hidden_size = 256 - - inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" - with te.fp8_autocast(enabled=use_fp8): - model = te.LayerNormMLP( - hidden_size, - ffn_hidden_size, - bias=use_bias, - return_bias=return_bias, - return_layernorm_output=return_layernorm_output, - params_dtype=precision, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - normalization=normalization, - ).to(device="cuda") - if use_fp8: - set_layer_scale(model, scale_factor, num_gemms=2) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16,): - return - atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) - 
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs) - - -@skip_FP8 -@pytest.mark.parametrize( - "precision, use_mask, attn_mask_type", - [ - (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - ], -) -def test_export_core_attention( - seed_default_rng, - set_max_seq_len, - precision: torch.dtype, - use_mask: bool, - attn_mask_type: str, -): - # Set dimensions (these are arbitrary). - seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) - qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) - qkv_format = "sbhd" - - query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - input_names = ["query", "key", "value", "attention_mask"] - attention_mask = None - if use_mask: - # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (query_layer, key_layer, value_layer, attention_mask) - - mask_str = get_attn_mask_str(use_mask, attn_mask_type) - high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" - - model = te.attention.DotProductAttention( - num_attention_heads=num_attention_heads, - kv_channels=kv_channels, - attention_dropout=0.5, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, use_fp8=True) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16,): - return - validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs - ) - - -test_configs_multihead_attention = [ - # "use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledSoftmax - (True, "arbitrary"), # calls ScaledMaskedSoftmax -] -test_configs_attention_type = [ - # "input_layernorm, attention_type, fuse_qkv_params" - (True, "self", True), - (False, "self", True), - (True, "self", False), - (False, "self", False), - (True, "cross", True), - (False, "cross", True), - (True, "cross", False), - (False, "cross", False), -] - - -@pytest.mark.parametrize("use_fp8", [False, True]) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type -) -def test_export_multihead_attention( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - use_mask: bool, - attn_mask_type: str, - precision: torch.dtype, - 
return_layernorm_output: bool, - input_layernorm: bool, - attention_type: str, - fuse_qkv_params: bool, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - hidden_size = 256 - sequence_length = 128 - batch_size = 4 - num_attention_heads = 32 - kv_channels = 8 - attention_dropout = 0.1 - layernorm_epsilon = 1e-5 - init_method = output_layer_init_method = get_default_init_method() - attention_args = ( - hidden_size, - num_attention_heads, - kv_channels, - attention_dropout, - layernorm_epsilon, - init_method, - output_layer_init_method, - ) - - hidden_states_context = torch.randn( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - attention_mask = None - if use_mask and attn_mask_type != "causal": - # Generate a random mask with 50% probability for 0 or 1. - probs = 0.5 * torch.ones( - batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision - ) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - - encoder_output = None - - if attention_type == "cross": - encoder_output = torch.randn( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - - fp8_str = "_fp8" if use_fp8 else "" - dtype_str = dtype2str(precision) - attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention" - fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else "" - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - input_ln_str = "_input-ln" if input_layernorm else "" - fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" - - model = te.MultiheadAttention( - *attention_args, - attn_mask_type=attn_mask_type, - params_dtype=precision, - return_layernorm_output=return_layernorm_output, - input_layernorm=input_layernorm, - attention_type=attention_type, - fuse_qkv_params=fuse_qkv_params, - return_bias=True, - ).to(device="cuda") - - inp_context 
= (hidden_states_context, attention_mask, encoder_output) - input_names = ["hidden_states", "attention_mask", "encoder_output"] - output_names = ["attention_output", "attention_bias"] - do_export( - model, - inp_context, - fname, - use_fp8, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "hidden_states": {0: "seq", 1: "bs"}, - "attention_output": {0: "seq", 1: "bs"}, - }, - ) - te_outputs = te_infer(model, inp_context, is_fp8=use_fp8) - serialize_inputs_outputs( - fname, inp_context, te_outputs, input_names=input_names, output_names=output_names - ) - if precision in (torch.bfloat16,): - return - - if not use_fp8: - validate_result( - fname, - inp_context, - model, - atol=1e-3, - input_names=input_names, - output_names=output_names, - te_outputs=te_outputs, - ) - else: - validate_result( - fname, - inp_context, - model, - atol=1e-2, - is_fp8=use_fp8, - input_names=input_names, - output_names=output_names, - allow_cnt_errors=3, - te_outputs=te_outputs, - ) - - # In GPT generative phase (inference) the input sequence is smaller than the maximum - # allowed sequence length and we want to test this condition. - # Pretend that we're in generative phase when it makes sense (causal mask and self-attention). 
- is_generative_phase = attn_mask_type == "causal" and attention_type == "self" - if is_generative_phase: - seq_len_offset = 8 - hidden_states_generative = torch.randn( - sequence_length - seq_len_offset, - batch_size, - hidden_size, - dtype=precision, - device="cuda", - ) - inp_generative = (hidden_states_generative, attention_mask, encoder_output) - if not use_fp8: - validate_result( - fname, - inp_generative, - model, - atol=1e-3, - input_names=input_names, - output_names=output_names, - ) - else: - validate_result( - fname, - inp_generative, - model, - atol=1e-2, - is_fp8=use_fp8, - input_names=input_names, - output_names=output_names, - allow_cnt_errors=3, - ) - - -@pytest.mark.parametrize("use_fp8", [False, True]) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize( - "output_layernorm", - [ - # True, # TO DO: handle this - False - ], -) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("fuse_qkv_params", [False, True]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -def test_export_transformer_layer( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - use_mask: bool, - attn_mask_type: str, - output_layernorm: bool, - precision: torch.dtype, - fuse_qkv_params: bool, - zero_centered_gamma: bool, - activation: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Layer configuration - hidden_size = 64 - sequence_length = 128 - batch_size = 1 - ffn_hidden_size = 256 - num_attention_heads = 4 - - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - input_names = ["input", "attention_mask"] - attention_mask = None - if use_mask and attn_mask_type != "causal": - # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones( - batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision - ) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (input_tensor, attention_mask) - - fp8_str = "_fp8" if use_fp8 else "" - fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" - high_prec_str = dtype2str(precision) - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx" - - model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attention_heads, - self_attn_mask_type=attn_mask_type, - output_layernorm=output_layernorm, - params_dtype=precision, - fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - ).to(device="cuda") - do_export(model, inp, fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16,): - return - atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("use_fp8", [True]) -@pytest.mark.parametrize("ln_scale_factor", [448 * 2]) -@pytest.mark.parametrize( - "gemm_scale_factors", - [ - ( - 224, - 224, - ), - ], -) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -def test_export_gemm_layernorm( - seed_default_rng, - use_fp8: bool, - ln_scale_factor: float, - gemm_scale_factors: Tuple[float, float], - precision: torch.dtype, - zero_centered_gamma: bool, -): - """This is a regression test for testing that all LN inputs have the same type. 
- - The test sets up GEMM with FP32 output which feeds into an LN that is configured - with FP16 or BF16 weights and bias. - """ - out_features = 128 - hidden_size = 128 - in_features = 128 - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - class TestFP8_GemmLayernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") - self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(ln_scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - self.gemm = FP8GemmModule( - precision, - use_bias=False, - gelu=False, - scale_factors=gemm_scale_factors, - hidden_size=hidden_size, - out_features=out_features, - ) - - def forward(self, inp, weight): - x = self.gemm(inp, weight) - x = texcpp.layernorm_fwd_fp8_inf( - x, - self.weight, - self.bias, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - x = cast_from_fp8( - x, - self.meta, - self.fp8_tensor, - self.fp8_type, - tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16, - ) - return x - - inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") - weight = torch.randn(out_features, in_features, dtype=precision, device="cuda") - model = TestFP8_GemmLayernorm() - high_prec_str = dtype2str(precision) - fp8_str = f"_fp8" if use_fp8 else "" - fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx" - input_names = ["input", "weight"] - do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision not in (torch.bfloat16,): - 
validate_result( - fname, - (inp, weight), - model, - atol=5e-2, - is_fp8=use_fp8, - allow_cnt_errors=2, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@skip_FP8 -@pytest.mark.parametrize("use_fp8", [True, False]) -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [True]) -def test_export_gpt_generation( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - precision: torch.dtype, - zero_centered_gamma: bool, -): - """Test that the ONNX model can correctly handle inputs with different shapes and that - the attention mask it adjusted on-the-fly to different sequence lengths. - """ - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Layer configuration - hidden_size = 64 - sequence_length = 128 - batch_size = 1 - ffn_hidden_size = 256 - num_attention_heads = 4 - attention_mask = None - use_mask = True - attn_mask_type = "causal" - fuse_qkv_params = True - output_layernorm = False - - fp8_str = "_fp8" if use_fp8 else "" - fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" - high_prec_str = dtype2str(precision) - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx" - - model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attention_heads, - self_attn_mask_type=attn_mask_type, - output_layernorm=output_layernorm, - params_dtype=precision, - fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, - ).to(device="cuda") - - # "Context phase": use full input sequence length - input_names = ["input"] - output_names = ["output"] - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - inp = (input_tensor,) - do_export( - model, - inp, - fname, - use_fp8, - input_names=input_names, - output_names=output_names, - 
dynamic_axes={ - "input": {0: "seq", 1: "bs"}, - "output": {0: "seq", 1: "bs"}, - }, - ) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs( - fname, inp, te_outputs, input_names=input_names, output_names=output_names - ) - if precision not in (torch.bfloat16,): - validate_result( - fname, - inp, - model, - atol=6e-3, - is_fp8=use_fp8, - input_names=input_names, - te_outputs=te_outputs, - ) - - # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8. - sequence_length = 1 if not use_fp8 else 8 - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - inp = (input_tensor, attention_mask) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision not in (torch.bfloat16,): - validate_result( - fname, - inp, - model, - atol=6e-3, - is_fp8=use_fp8, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize("enabled", [True, False]) -def test_export_ctx_manager(enabled): - assert is_in_onnx_export_mode() == False - with te.onnx_export(enabled): - assert is_in_onnx_export_mode() == enabled - assert is_in_onnx_export_mode() == False diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index c29c01b433..35c6266a3f 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -15,7 +15,7 @@ ) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer import transformer_engine_torch as tex @@ -246,20 +246,28 @@ def _test_permutation_index_map( unpermute_bwd_input = torch.rand( size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" ) - - 
permute_fwd_input = Float8Tensor.to_float8( - permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - permute_bwd_input = Float8Tensor.to_float8( - permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - unpermute_bwd_input = Float8Tensor.to_float8( - unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _unpermute_bwd_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) + permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) + unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize().to(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize().to(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize().to(torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -333,10 +341,10 @@ def _test_permutation_index_map( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.from_float8(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) - te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) - 
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + te_permute_output_ = te_permute_output.dequantize().to(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize().to(torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize().to(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize().to(torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 646dea552e..dcac5f1500 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -15,6 +15,7 @@ _amax_and_scale_update, get_default_fp8_recipe, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops import transformer_engine_torch as tex @@ -64,17 +65,17 @@ def test_fp8_scale_update_with_linear_module( forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) amax_history_forward = fp8_meta[forward_key].amax_history scale_forward = fp8_meta[forward_key].scale - scale_inv_forward = fp8_meta[forward_key].scale_inv + # scale_inv_forward = fp8_meta[forward_key].scale_inv backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) amax_history_backward = fp8_meta[backward_key].amax_history scale_backward = fp8_meta[backward_key].scale - scale_inv_backward = fp8_meta[backward_key].scale_inv + # scale_inv_backward = fp8_meta[backward_key].scale_inv # Tweak amax history and scaling factors amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5) amax_history_forward[0, :].zero_() scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5) - scale_inv_forward.copy_(torch.reciprocal(scale_forward)) + # scale_inv_forward.copy_(torch.reciprocal(scale_forward)) amax_history_backward[0, :].zero_() # Expected amax history after update @@ -100,11 
+101,11 @@ def test_fp8_scale_update_with_linear_module( raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) - ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + # ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) update_weight_amax = is_first_microbatch is None or is_first_microbatch - if not update_weight_amax: - ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) - ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + # if not update_weight_amax: + # ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) + # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Perform forward, backward, and optimizer steps to update fp8_meta with te.fp8_autocast(enabled=True, fp8_recipe=recipe): @@ -133,8 +134,8 @@ def test_fp8_scale_update_with_linear_module( raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) - ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) - ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + # ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Check that scale and scale inverse match expected values # Note: scale and scale inverse are only updated when amax is updated @@ -142,27 +143,15 @@ def test_fp8_scale_update_with_linear_module( scale_forward[0], ref_scale_forward[0], ) - torch.testing.assert_close( - scale_inv_forward[0], - ref_scale_inv_forward[0], - ) if update_weight_amax: torch.testing.assert_close( scale_forward[1], ref_scale_forward[1], ) - torch.testing.assert_close( - scale_inv_forward[1], - ref_scale_inv_forward[1], - ) torch.testing.assert_close( scale_backward[0], 
ref_scale_backward[0], ) - torch.testing.assert_close( - scale_inv_backward[0], - ref_scale_inv_backward[0], - ) @pytest.mark.parametrize("amax_history_len", [31, 1024]) @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"]) @@ -180,12 +169,23 @@ def test_fp8_scale_update_with_linear_fuser_op( # Construct linear op op = te_ops.BasicLinear(in_shape[-1], in_shape[-1]) - # Get FP8 meta tensors + # FP8 recipe forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - x_fp8_meta = op.get_fp8_meta("input")[forward_key] - w_fp8_meta = op.get_fp8_meta("param")[forward_key] - dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] + fp8_format = transformer_engine.common.recipe.Format.HYBRID + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + interval=1, + fp8_format=fp8_format, + amax_history_len=amax_history_len, + amax_compute_algo=amax_compute_algo, + ) + + # Get FP8 meta tensors + with te.fp8_autocast(fp8_recipe=recipe): + x_fp8_meta = op.get_quantizer("forward", 0) + w_fp8_meta = op.get_quantizer("forward", 1) + dy_fp8_meta = op.get_quantizer("backward", 0) # Perform training steps x_history = [] @@ -214,14 +214,6 @@ def test_fp8_scale_update_with_linear_fuser_op( op.weight.fill_(w_history[-1]) # Forward and backward pass - fp8_format = transformer_engine.common.recipe.Format.HYBRID - recipe = transformer_engine.common.recipe.DelayedScaling( - margin=margin, - interval=1, - fp8_format=fp8_format, - amax_history_len=amax_history_len, - amax_compute_algo=amax_compute_algo, - ) with te.fp8_autocast(fp8_recipe=recipe): y = op(x) y.backward(dy) @@ -247,7 +239,7 @@ def check_amax_history( ) def check_scale( - fp8_meta: dict, + quantizer: Float8Quantizer, ref_amax_history: Iterable[float], stage: str, ): @@ -272,18 +264,11 @@ def check_scale( # Check values in FP8 meta tensors torch.testing.assert_close( - fp8_meta.scale.item(), + 
quantizer.scale.item(), ref_scale, ) - torch.testing.assert_close( - fp8_meta.scale_inv.item(), - 1 / ref_scale, - ) # Check that results match expected values - check_amax_history(x_fp8_meta, x_history) - check_amax_history(w_fp8_meta, w_history) - check_amax_history(dy_fp8_meta, dy_history) check_scale(x_fp8_meta, x_history, "forward") check_scale(w_fp8_meta, w_history, "forward") check_scale(dy_fp8_meta, dy_history, "backward") @@ -369,7 +354,6 @@ def setup_fp8_meta(): fp8_meta[forward_key].amax_history.clone().view(-1), [fp8_meta[forward_key].amax_history], [fp8_meta[forward_key].scale], - [fp8_meta[forward_key].scale_inv], recipe.amax_compute_algo, fp8_dtype, recipe.margin, @@ -378,12 +362,8 @@ def setup_fp8_meta(): _amax_and_scale_update( fp8_meta[forward_key].amax_history, fp8_meta[forward_key].scale, - fp8_meta[forward_key].scale_inv, fp8_max, recipe, ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) - torch.testing.assert_close( - fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale) - ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index daf8506593..d3bf34943d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -8,7 +8,6 @@ import torch import pytest -import io import os from transformer_engine.pytorch.fp8 import ( @@ -34,19 +33,22 @@ ) from transformer_engine.common import recipe import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import ( - gemm, - fp8_gemm, - gelu, - cast_to_fp8, - cast_from_fp8, -) +from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from test_onnx_export import create_meta +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from test_numerics import reset_rng_states, dtype_tols -# Only run FP8 tests on H100. +# Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() + + +def create_meta(scale_factor: float, size: int = 1): + meta = tex.FP8TensorMeta() + meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") + meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor + meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor + return meta def custom_amax_to_scale( @@ -96,13 +98,9 @@ def is_fp8_supported(self): fp8_recipes = [ None, # Handles non-FP8 case + recipe.MXFP8BlockScaling(), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID), - recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.E4M3, - override_linear_precision=(False, False, True), - ), recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.E4M3, @@ -136,7 +134,7 @@ def is_fp8_supported(self): all_boolean = [True, False] batch_sizes_with_zero = [0, 1, 2] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"] +all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -236,6 +234,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): torch.cuda.synchronize() assert te_out.dtype == dtype, "AMP wrong output type." + assert te_inp_hidden_states.grad is not None, "Gradient should not be empty" assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type." 
for name, p in block.named_parameters(): if p.requires_grad: @@ -272,11 +271,14 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci loss.backward() torch.cuda.synchronize() + failed_grads = [] for name, p in block.named_parameters(): if "layer_norm_weight" in name: continue elif "weight" in name and p.requires_grad: - assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated." + if not torch.count_nonzero(p.main_grad) > 0: + failed_grads.append(name) + assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): @@ -411,6 +413,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) torch.cuda.synchronize() assert te_out.dtype == dtype, "AMP wrong output type." + assert te_inp.grad is not None, "Gradient should not be empty" assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type." for name, p in block.named_parameters(): if p.requires_grad: @@ -445,6 +448,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -474,6 +479,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -504,11 +511,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model 
config does not support FP8") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params): + with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_linear = Linear( config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -539,6 +548,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -587,6 +598,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -652,6 +665,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -709,6 +724,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -764,6 +781,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -797,6 +816,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, 
skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -833,6 +854,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -872,6 +895,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -912,6 +937,8 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -962,7 +989,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): inp = torch.reshape(scratchpad[offset:-offset], (N, N)) weight = torch.reshape(scratchpad[offset * 2 :], (N, N)) - _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace()) + _ = general_gemm(A=weight, B=inp, workspace=get_workspace()) torch.cuda.synchronize() @@ -971,35 +998,24 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) def test_sanity_fp8_gemm_with_unalignment(N, datatype): offset = 16 - scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype) + scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype) - 
fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT + scales = torch.ones(1).cuda().squeeze() + amaxes = torch.ones(1).cuda().squeeze() + dtype = tex.DType.kFloat8E4M3 + fp8_quantizer = Float8Quantizer(scales, amaxes, dtype) - nb_inp_scales, nb_weight_scales = 1, N - scale_factor = 1.0 - meta_inp = create_meta(scale_factor, nb_inp_scales) - meta_weight = create_meta(scale_factor, nb_weight_scales) - inp_type = tex.DType.kFloat8E4M3 - weights_type = tex.DType.kFloat8E4M3 outp_type = datatype - scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type) - inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N)) - weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N)) - _, _ = fp8_gemm( + scratchpad_fp8 = fp8_quantizer(scratchpad) + inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N)) + weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N)) + general_gemm( weight_fp8, - meta_weight.scale_inv, - fp8_tensor_weight, - inp_type, inp_fp8, - meta_inp.scale_inv, - fp8_tensor_inp, - weights_type, - outp_type, get_workspace(), + outp_type, bias=None, - use_bias=False, use_split_accumulator=False, ) torch.cuda.synchronize() @@ -1062,13 +1078,15 @@ def get_model(dtype, config): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_enabled): + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, fuse_qkv_params=True, params_dtype=dtype, device="cuda", diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py deleted file mode 100644 index 46ce33becc..0000000000 --- a/tests/pytorch/test_torch_save_load.py +++ /dev/null @@ -1,474 +0,0 @@ 
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -This file contains tests for saving and loading TransformerEngine torch checkpoints. - -The purpose of this test is to validate the TransformerEngine hooks for saving FP8 metadata -in torch checkpoints, which are called as part of torch.save() and torch.load(). -The test verifies the values of FP8 metadata object after saving and loading a checkpoint -are identical to the original values. -""" - -import io -import tempfile -from typing import Iterable, Union - -import pytest -import torch -import transformer_engine.common -import transformer_engine.pytorch as te -import transformer_engine.pytorch.ops as te_ops -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() - - -def init_meta(size: int = 1): - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - return meta - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("scale_fwd", [224, 112, 66]) -@pytest.mark.parametrize("scale_bwd", [448, 33]) -@pytest.mark.parametrize("history_fwd", [1.23, 4.56]) -@pytest.mark.parametrize("history_bwd", [2.34, 5.67]) -def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd): - - tmp_filename = tempfile.NamedTemporaryFile().name - - precision = torch.float32 - - class Test_TE_Export(TransformerEngineBaseModule): - def 
__init__(self, precision, use_bias): - super().__init__() - self.use_bias = use_bias - self.precision = precision - - self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT - nb_inp_scales = nb_weight_scales = 1 - self.meta_inp = init_meta(nb_inp_scales) - self.meta_weight = init_meta(nb_weight_scales) - - bias_size = nb_weight_scales - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - - self.inp_type = tex.DType.kFloat8E4M3 - self.weights_type = tex.DType.kFloat8E4M3 - self.outp_type = precision - - def get_fp8_weights_scratchpad(self, is_first_microbatch): - raise RuntimeError( - "Method get_fp8_weights_scratchpad is dummy and should not be invoked." - ) - - def forward(self, inp, weight): - inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) - - weight_fp8 = cast_to_fp8( - weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type - ) - - ret = fp8_gemm( - weight_fp8, - self.meta_weight.scale_inv, - self.fp8_tensor_weight, - self.inp_type, - inp_fp8, - self.meta_inp.scale_inv, - self.fp8_tensor_inp, - self.weights_type, - self.outp_type, - get_workspace(), - bias=self.bias, - use_bias=self.use_bias, - use_split_accumulator=False, - ) - return ret - - model_in = Test_TE_Export(precision, True) - with te.fp8_autocast(enabled=True): - model_in.init_fp8_metadata() - # scaling fwd - model_in.fp8_meta["scaling_fwd"].scale = ( - torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd - ) - model_in.fp8_meta["scaling_fwd"].scale_inv = ( - torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd - ) - model_in.fp8_meta["scaling_fwd"].amax_history = ( - torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd - ) - # scaling bwd - model_in.fp8_meta["scaling_bwd"].scale = ( - torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd - ) - model_in.fp8_meta["scaling_bwd"].scale_inv = ( - torch.ones(2, dtype=torch.float32, device="cuda") / 
scale_bwd - ) - model_in.fp8_meta["scaling_bwd"].amax_history = ( - torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd - ) - - torch.save(model_in.state_dict(), tmp_filename) - - model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename, weights_only=False)) - model_out.eval() - - # scaling fwd - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale - ) - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv - ) - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].amax_history, - model_out.fp8_meta["scaling_fwd"].amax_history, - ) - # scaling bwd - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale - ) - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv - ) - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].amax_history, - model_out.fp8_meta["scaling_bwd"].amax_history, - ) - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("save_fp8_model", [True, False]) -@pytest.mark.parametrize("load_fp8_model", [True, False]) -def test_fp8_model_checkpoint( - save_fp8_model: bool, - load_fp8_model: bool, - dims: Iterable[int] = [32, 32], - dtype: torch.dtype = torch.float32, - device: Union[torch.device, str] = "cuda", -): - - # Construct model - dims = list(dims) - hidden_dim = dims[-1] - with te.fp8_model_init(enabled=save_fp8_model): - model = te.Linear( - hidden_dim, - hidden_dim, - bias=False, - params_dtype=dtype, - device=device, - ) - # Keep track of model output - x = torch.randn(dims, dtype=dtype, device=device) - with te.fp8_autocast(): - y_ref = model(x.detach().clone()).detach().clone() - - fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}} - with te.fp8_autocast(), torch.no_grad(): - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] 
- fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5 - fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal() - fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5 - fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal() - fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"]) - fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"]) - fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) - fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) - del fp8_meta_fwd, fp8_meta_bwd - - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. - # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. - # It is essential for these values to be equal, so setting scale_inv only in the model metadata is insufficient. - model.weight.data.copy_(model.weight.float().cuda()) - # After copying, the tensor computes the meta scale_inv based on the amax history; we then reset these values. 
- model.fp8_meta["scaling_fwd"].scale = fp8_meta_fwd_ref["scale"] - model.fp8_meta["scaling_fwd"].scale_inv = fp8_meta_fwd_ref["scale_inv"] - - # Keep track of weights and FP8 scaling factors - weight_ref = model.weight.float().detach().clone() - - # Save checkpoint - byte_stream = io.BytesIO() - torch.save(model.state_dict(), byte_stream) - model_bytes = byte_stream.getvalue() - del byte_stream - - # Disturb and destroy model - with torch.no_grad(): - model.weight.zero_() - model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234} - del model - - # Construct new model - with te.fp8_model_init(enabled=load_fp8_model): - model = te.Linear( - hidden_dim, - hidden_dim, - bias=False, - params_dtype=dtype, - device=device, - ) - - # Make sure new model does not match saved model - tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 - with pytest.raises(AssertionError): - torch.testing.assert_close(model.weight, weight_ref, **tols) - with te.fp8_autocast(): - model.init_fp8_metadata() - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) - with te.fp8_autocast(): - y = model(x.detach().clone()) - with pytest.raises(AssertionError): - torch.testing.assert_close(y, y_ref, **tols) - - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # When save_fp8_model=True, we load a model with weights in high precision, - # which 
does not include _scale_inv, - # but has the fp8 scaling factor in the meta data. This scenario can occur - # when using te.fp8_autocast(enabled=False, calibrating=True). - # - # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, - # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior - # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, - # to load the fp8 metadata before loading tensors. - # - # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) - del model_bytes - - # Check that loaded model matches saved model - torch.testing.assert_close(model.weight, weight_ref, **tols) - with te.fp8_autocast(): - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) - torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) - torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) - torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) - with te.fp8_autocast(): - y = model(x.detach().clone()) - torch.testing.assert_close(y, y_ref, **tols) - - if load_fp8_model: - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # We need to ensure that the tensor's scale_inv parameter matches its meta data. - # This is crucial to avoid confusion about which value is correct. 
- meta_index = model.weight._fp8_meta_index - torch.testing.assert_close( - model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item() - ) - - -@pytest.mark.parametrize("fp8", (False, True)) -@pytest.mark.parametrize("save_fp8_model", (False, True)) -@pytest.mark.parametrize("load_fp8_model", (False, True)) -def test_sequential_model( - *, - in_shape: Iterable[int] = (16, 16), - dtype: torch.dtype = torch.float32, - device: torch.device = "cuda", - save_steps: int = 2, - load_steps: int = 2, - fp8: bool, - save_fp8_model: bool, - load_fp8_model: bool, -) -> None: - - # Skip invalid configurations - if fp8 or save_fp8_model or load_fp8_model: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - - # FP8 recipe - margin = 2 - fp8_format = transformer_engine.common.recipe.Format.E4M3 - recipe = transformer_engine.common.recipe.DelayedScaling( - margin=margin, - fp8_format=fp8_format, - amax_history_len=8, - amax_compute_algo="max", - ) - - # Construct model to save to checkpoint - with te.fp8_model_init(enabled=save_fp8_model): - model = te_ops.Sequential( - te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), - ) - with torch.no_grad(): - torch.rand(model[0].weight.size(), out=model[0].weight) - torch.rand(model[0].bias.size(), out=model[0].bias) - - # Synthetic data - xs_ref = [ - torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) - ] - dys_ref = [ - torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) - ] - - def train_step( - model: te_ops.Sequential, - x: torch.Tensor, - dy: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Helper function to perform training step""" - x = x.detach().clone().requires_grad_() - dy = dy.detach().clone() - with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe): - y = model(x) - y.backward(dy) - 
with torch.no_grad(): - for param in model.parameters(): - param += 0.125 - return ( - y.detach().clone(), - x.grad.detach().clone(), - model[0].weight.detach().float().clone(), - ) - - # Initial training steps with saved model - ys_ref = [] - dxs_ref = [] - ws_ref = [] - for step in range(save_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - ys_ref.append(y) - dxs_ref.append(dx) - ws_ref.append(w) - - # Keep track of FP8 metadata if needed - fp8_meta_ref = dict(input={}, param={}, grad_output={}) - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - m_ref["amax"] = m_model.amax_history.detach().clone() - m_ref["scale"] = m_model.scale.detach().clone() - m_ref["scale_inv"] = m_model.scale_inv.detach().clone() - del m_model, m_ref - - # Save checkpoint - byte_stream = io.BytesIO() - torch.save(model.state_dict(), byte_stream) - model_bytes = byte_stream.getvalue() - del byte_stream - - # More training steps with saved model - for step in range(save_steps, save_steps + load_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - ys_ref.append(y) - dxs_ref.append(dx) - ws_ref.append(w) - - # Disturb and destroy model - with torch.no_grad(): - for param in model.parameters(): - param.zero_() - model[0].basic_ops[0]._fp8_metas = None - del model - - # Construct new model to load from checkpoint - with te.fp8_model_init(enabled=load_fp8_model): - model = te_ops.Sequential( - te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), - ) - - # Tolerances for numerical checks - tols = {} - if fp8 or save_fp8_model or load_fp8_model: - tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 - exact_tols = dict(rtol=0, atol=0) - - # Training steps with dummy data - for step in range(save_steps): - y, dx, w = 
train_step( - model, - torch.zeros_like(xs_ref[step]), - torch.zeros_like(dys_ref[step]), - ) - - # Make sure results don't match saved model - with pytest.raises(AssertionError): - torch.testing.assert_close(y, ys_ref[step], **tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(dx, dxs_ref[step], **tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(w, ws_ref[step], **tols) - - # Make sure new model's FP8 metadata doesn't match saved model - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) - - # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) - del model_bytes - - # Check that new model's FP8 metadata matches saved model - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) - torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) - torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) - - # More training steps with loaded model - for step in range(save_steps, save_steps + load_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - torch.testing.assert_close(y, ys_ref[step], **tols) - 
torch.testing.assert_close(dx, dxs_ref[step], **tols) - torch.testing.assert_close(w, ws_ref[step], **tols) diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index d97d9653e6..8b80364a3d 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -19,19 +19,9 @@ except (ImportError, StopIteration) as e: pass -try: - from . import paddle -except (ImportError, StopIteration) as e: - pass - try: import transformer_engine_jax except ImportError: pass -try: - import transformer_engine_paddle -except ImportError: - pass - __version__ = str(metadata.version("transformer_engine")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3afddcc48d..ed59153954 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -6,13 +6,17 @@ cmake_minimum_required(VERSION 3.21) # Language options if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + endif() endif() set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") endif() # Hide non-necessary symbols in shared object. 
@@ -78,6 +82,7 @@ list(APPEND transformer_engine_SOURCES util/cuda_runtime.cpp util/rtc.cpp util/system.cpp + swizzle/swizzle.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index ddb786bd3a..708403f911 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -4,111 +4,71 @@ * See LICENSE for license information. ************************************************************************/ +/*! \file activation_template.h + * \brief Activation functions template. + */ + +#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ +#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ + #include #include #include "../common.h" +#include "../util/cast_gated_kernels.cuh" +#include "../util/cast_kernels.cuh" +#include "../util/math.h" #include "../util/vectorized_pointwise.h" namespace transformer_engine { template -void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "act_lu_input"); - CheckOutputTensor(*output, "act_lu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - const size_t tot_elts = product(input.data.shape); +void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = true; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - 
reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } template -void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "dact_lu_input"); - CheckInputTensor(grad, "dact_lu_input_grad"); - CheckOutputTensor(*output, "dact_lu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match."); - const size_t tot_elts = product(input.data.shape); +void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } -template -void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - 
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input shape[1] must be 2x larger than output shape[1]."); +template +void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = false; + constexpr NVTETensor grad = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], - output->data.shape[1], {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_gated_helper(grad, input, output, stream); } -template -void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); +template +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t 
stream) { + using namespace detail; + constexpr bool IS_DGATED = true; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), grad.data.shape[0], grad.data.shape[1], - {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_gated_helper(grad, input, output, stream); } } // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index cb38b351e9..0cf43007a7 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -3,69 +3,58 @@ * * See LICENSE for license information. ************************************************************************/ + #include "../util/math.h" #include "./activation_template.h" void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_gelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t 
stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dgelu>(grad, input, output, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dqgelu>(grad, input, output, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index 7653991819..a794b7315f 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -10,63 +10,51 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_relu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_drelu(const NVTETensor grad, const NVTETensor 
input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, drelu>(grad, input, output, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_srelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + 
dgated_act_fn, dsrelu>(grad, input, output, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 5a0e0ead84..8194964745 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -10,31 +10,25 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_silu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dsilu>(grad, input, output, stream); } diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 003ea9588c..d988de6f66 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,8 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#define AS_VECTOR(shape) std::vector(shape.data, shape.data + shape.ndim) + using namespace std::placeholders; 
namespace transformer_engine { @@ -40,8 +42,9 @@ bool ubuf_built_with_mpi() { CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm) { + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { // Initialize userbuf communicator if (!_comm_created) { if (myrank == 0) { @@ -59,9 +62,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; + if (gemm_priority == 0 && comm_priority == 0) { + transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority); + } else { + _gemm_priority = gemm_priority; + _comm_priority = comm_priority; + } for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); _stream_compute.push_back(std::move(stream)); } @@ -130,6 +139,73 @@ CommOverlapCore::~CommOverlapCore() { } } +TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, + const std::vector &chunk_shape) { + TensorWrapper chunk; + for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { + auto param_type = static_cast(param_id); + auto param = source.get_parameter(param_type); + auto param_dptr = reinterpret_cast(param.data_ptr); + auto param_dtype = static_cast(param.dtype); + auto param_shape = AS_VECTOR(param.shape); + + if (param_dptr != nullptr) { + if (param_type == NVTETensorParam::kNVTERowwiseData || + param_type == NVTETensorParam::kNVTEColumnwiseData) { + // Offset 
data pointer + param_dptr += chunk_offset * typeToSize(param_dtype); + param_shape = chunk_shape; + + if (param_type == NVTETensorParam::kNVTEColumnwiseData && + source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { + // Columnwise shape for non-block scaled tensors shifts the last dimension to the front + auto last_dim = param_shape.back(); + param_shape.pop_back(); + param_shape.insert(param_shape.begin(), last_dim); + } + } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING && + (param_type == NVTETensorParam::kNVTERowwiseScaleInv || + param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { + // Calculate block scaling offset and size + auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? source.shape().data[0] + : source.columnwise_shape().data[0]; + auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? chunk_shape.front() + : chunk_shape.back(); + auto chunk_scale_start = chunk_offset / 32; + auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32; + auto chunk_scale_size = chunk_scale_end - chunk_scale_start; + param_dptr += chunk_scale_start * typeToSize(param_dtype); + param_shape = std::vector{chunk_scale_size}; + } + + // Set chunked source parameters into the chunked tensor output + chunk.set_parameter(param_type, reinterpret_cast(param_dptr), param_dtype, + param_shape); + } + } + return chunk; +} + +TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source, + size_t chunk_offset, + const std::vector &chunk_shape) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); + + // Update chunk with offset data pointers from the communication buffer + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); + } + if 
(chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), + chunk.columnwise_shape()); + } + return chunk; +} + /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ @@ -138,11 +214,14 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, - num_comm_sm, set_sm_margin, false, atomic_gemm) { + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, + atomic_gemm) { + _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", @@ -155,7 +234,8 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType if (_ub_comm->myrank == 0) printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); } @@ -168,8 +248,8 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, @@ -196,7 +276,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); } else { @@ -221,20 +301,20 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, +void CommOverlapBase::atomic_gemm_overlap_rs(const 
TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; // Get GEMM dimensions - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); @@ -255,9 +335,8 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens assert(pre_gelu_out.numel() == 0); - auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, {n, m}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), @@ -269,11 +348,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, &counter_ptr[i], _ub_comm, _stream_comm);); } else { 
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, @@ -282,11 +360,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens } } else if (_rs_kernel_type == 2) { if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, @@ -299,7 +376,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens if (_ubuf.element_size() == 1) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, _stream_comm);); } else { @@ -321,34 +398,24 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t m = A.size(0); - size_t k 
= A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t input_a_chunk_size = m_chunk * k; size_t output_chunk_size = n * m_chunk; - size_t bias_chunk_size = m_chunk; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); - char *bias_chunk_ptr = reinterpret_cast(bias.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (size_t i = 0; i < _stream_compute.size(); i++) { @@ -358,39 +425,23 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { - auto input_a_chunk = - TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = - TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto bias_chunk = - TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr); - auto workspace_chunk = TensorWrapper( - workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_rs_overlap_first_gemm) { + auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, 
workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[0]); for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * D.element_size(); - if (bias_chunk_ptr != nullptr) { - bias_chunk_ptr += bias_chunk_size * bias.element_size(); - } - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, - D.dtype(), D.amax(), D.scale(), nullptr); - bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, bias.dtype(), - nullptr, nullptr, nullptr); - workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -401,11 +452,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Communication chunk if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + rs_output_ptr, 
D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, @@ -422,12 +472,11 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, _stream_comm);); + rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, + n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, @@ -435,20 +484,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } } else { for (int i = 0; i < _num_splits; i++) { - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), - {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, - bias.dtype(), nullptr, nullptr, nullptr); - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + auto workspace_chunk = 
get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -461,11 +502,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, @@ -473,11 +513,6 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - if (bias_chunk_ptr != nullptr) { - bias_chunk_ptr += bias_chunk_size * bias.element_size(); - } } } @@ -499,11 +534,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm, bool aggregate) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, - 
num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -552,8 +589,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); + for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + _stream_send.push_back(std::move(stream)); + } + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); } @@ -562,7 +604,22 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - cudaStreamDestroy(_stream_send); + for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); +} + +TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, + size_t chunk_id) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape())); + + // Update chunk with offset data pointers from the communication buffer + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); + } + return chunk; } /* @@ -570,12 +627,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ** This function 
assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -583,8 +638,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Get GEMM dimensions between TN and NN input layouts const size_t m = (transa) ? 
A.size(0) : A.size(1); - const size_t n = _ubuf.size(0); - const size_t n_chunk = n / _tp_size; + const size_t n_chunk = _ubufs[0].size(0); assert(pre_gelu_out.numel() == 0); // Get communication and GEMM output chunk sizes @@ -594,7 +648,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T void *D_buffer_ptr; int D_chunk_bytes = n_chunk * m * D.element_size(); NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); - auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), + D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); // Reset atomic counters int *counter_ptr = reinterpret_cast(_counter.dptr()); @@ -602,13 +657,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape())); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); for (int i = 0; i < _tp_size - 1; i++) { // Set the userbuffer id. 
Buffer under send is the input for the current @@ -649,8 +703,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T NVTE_CHECK_CUDA( cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); } @@ -674,11 +728,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -691,24 +746,20 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? 
(n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.dptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); } if (_aggregate) { const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + input_chunk_size *= 2; + output_chunk_size *= 2; // Initial 1X input chunk exchange between neighboring peers int send_chunk_id = _tp_id; @@ -717,11 +768,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - _stream_send); + _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; @@ -736,27 +787,15 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW recv_offset = comm_bytes * recv_chunk_id; // GEMM - char *input_b_chunk_ptr = input_b_ptr + send_offset; auto input_b_chunk = - TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = - (do_gelu) ? std::vector{n_chunk * 2, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m}); + auto aux_chunk = + (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -766,11 +805,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW if (i < num_steps - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, _stream_send); + next_rank, _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { @@ -778,7 +817,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); } } } else { @@ -793,24 +832,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = 
TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, - D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = (do_gelu) ? std::vector{n_chunk, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -820,11 +849,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW if (i < _tp_size - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, _stream_send); + _next_rank, _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], 
_stop_recv, 0)); } else if (B_copy.numel() > 0) { @@ -832,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); } } } @@ -842,7 +871,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); @@ -851,13 +880,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, - cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -876,14 +903,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, 
bool transa, T // Atomic GEMM // Process GEMM chunks in the order that AG+GEMM places the output chunks. - auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape())); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), - stream_main); + transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, _counter.data(), stream_main); // P2P communication chunk for (int i = 1; i < _tp_size; i++) { @@ -907,10 +930,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); @@ -921,31 +943,33 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, 
bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t k = A.size(1); - size_t n = B.size(0); // Get communication and GEMM input chunk sizes - size_t n_chunk = n / _tp_size; + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n_chunk = _ubufs[0].size(0); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); // Catch up the main stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + for (size_t i = 0; i < _stream_send.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0)); + } NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); @@ -954,36 +978,30 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW // GEMM and send/recv chunks for (int i = 0; i < _tp_size; i++) { // GEMM chunk + int stream_id = i % _stream_compute.size(); int input_b_chunk_id = (_tp_id + i + 
1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - - auto input_b_chunk = TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, - B.dtype(), nullptr, nullptr, B.scale_inv()); - - auto output_chunk = - TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); - char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); + use_split_accumulator, _math_sms, _stream_compute[stream_id]); if (i > 0) { // P2P communication chunk + int prev_stream_id = (i - 1) % _stream_compute.size(); int send_offset = comm_bytes * (i - 1); int recv_offset = comm_bytes * (i - 1 + _tp_size); int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - _stream_send); + 
_stream_send[prev_stream_id]); userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, _stream_recv); } @@ -993,8 +1011,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); @@ -1002,11 +1022,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index b2cd71f76b..735148a811 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -19,6 +19,7 @@ #include #include +#include 
"common/util/system.h" #include "userbuffers.h" #define MAX_THREADS 1024 diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 01b940f06a..cbeec66958 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -6,27 +6,138 @@ #include +#include + #include "./common.h" #include "./utils.cuh" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" namespace transformer_engine { namespace { __global__ void __launch_bounds__(1) - update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, - float* __restrict__ scale_inv_ptr) { + update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr, + float *__restrict__ scale_inv_ptr) { const float scale = scale_ptr == nullptr ? 1 : *scale_ptr; reciprocal(scale_inv_ptr, scale); } } // namespace -void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { - if (t->scale_inv.dptr != nullptr) { +void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { + if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { + NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( - reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + reinterpret_cast(t->scale.dptr), + reinterpret_cast(t->scale_inv.dptr)); } } +void checkCuDriverContext(CUstream stream) { + CUcontext ctx; + const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); + switch (driver_status) { + case CUDA_SUCCESS: + break; + + case CUDA_ERROR_INVALID_CONTEXT: + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx); + break; + + default: + const char *desc_NVTE_CHECK_CUDA_DRIVER; + cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER); + NVTE_ERROR("CUDA Error: ", 
desc_NVTE_CHECK_CUDA_DRIVER); + } +} + +CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { + static const std::unordered_map dtypeMapping = { + {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, + {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, + {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, + {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; + return dtypeMapping.at(dtype); +} + +inline bool isPointerAligned(const void *const ptr, const int alignment) { + const uint64_t ptr_as_uint = reinterpret_cast(ptr); + return ptr_as_uint % alignment == 0; +} + +// Set up parameters to create TMA descriptor. +void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size) { + // Get a function pointer to the cuTensorMapEncodeTiled driver API + static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() { + void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(driver_ptr); + }(); + // rank is the number of dimensions of the array + constexpr uint32_t rank = 2; + uint64_t size[rank] = {globalX, globalY}; + + // The stride is the number of bytes to traverse from the first element of one row to the next + uint64_t stride[rank - 1] = {stride_elems * type_size}; + + // The boxSize is the size of the shared memory buffer that is used as the + // source/destination of a TMA transfer + uint32_t boxSize[rank] = {shmemX, shmemY}; + + // The distance between elements in units of sizeof(element) + uint32_t elemStride[rank] = {1, 1}; + + const CUtensorMapDataType tensorDataType = 
get_CUtensorMapDataType(tensor.dtype); + void *dataPtr = + reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); + + constexpr int TMA_gmem_alignment = 16; // Alignment of the global memory address + NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment), + "Tensor data pointer must be 16B aligned"); + + const int TMA_needed_size = TMA_gmem_alignment / type_size; + NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, + "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); + + // Create the tensor descriptor. + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, // CUtensorMap *tensorMap, + tensorDataType, + rank, // cuuint32_t tensorRank, + dataPtr, // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + boxSize, // const cuuint32_t *boxDim, + elemStride, // const cuuint32_t *elementStrides, + // Interleave patterns can be used to accelerate loading of values that + // are less than 4 bytes long. + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + + // L2 Promotion can be used to widen the effect of a cache-policy to a wider + // set of L2 cache lines. + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + + // Any element that is outside of bounds will be set to zero by the TMA transfer. 
+ CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + +bool is_supported_by_CC_100() { + int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + return deviceComputeCapability >= 100; +} + } // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index d47ce472e5..ca9103532d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -7,6 +7,7 @@ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ +#include #include #include #include @@ -22,10 +23,29 @@ #include #include "./nvtx.h" +#include "./util/cuda_driver.h" #include "./util/logging.h" namespace transformer_engine { +inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { + NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", + end, " in a vector with ", shape.size(), " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} + +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + struct SimpleTensor { void *dptr; std::vector shape; @@ -33,20 +53,114 @@ struct SimpleTensor { SimpleTensor(void *dptr, const std::vector &shape, DType dtype) : dptr(dptr), shape(shape), dtype(dtype) {} + + SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT + : dptr(tensor.data_ptr), + shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim), + dtype(static_cast(tensor.dtype)) {} + SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} + + operator NVTEBasicTensor() const { + const NVTEShape shape = {this->shape.data(), this->shape.size()}; + return {dptr, static_cast(dtype), shape}; + } + + int numel() const { + size_t acc = 1; + for (const auto &dim : shape) { + acc *= dim; + } + return acc; + } }; struct Tensor { SimpleTensor data; + 
SimpleTensor columnwise_data; SimpleTensor amax; SimpleTensor scale; SimpleTensor scale_inv; + SimpleTensor columnwise_scale_inv; + + NVTEScalingMode scaling_mode; Tensor() : data(), + columnwise_data(), amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), - scale_inv(nullptr, {1}, DType::kFloat32) {} + scale_inv(nullptr, {1}, DType::kFloat32), + columnwise_scale_inv(nullptr, {1}, DType::kFloat32), + scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + + int numel() const { + NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, + "Tensor does not hold any data!"); + size_t acc = 1; + if (data.dptr != nullptr) { + for (const auto &dim : data.shape) { + acc *= dim; + } + return acc; + } + // data is empty, use columnwise_data + for (const auto &dim : columnwise_data.shape) { + acc *= dim; + } + return acc; + } + + bool has_data() const noexcept { return data.dptr != nullptr; } + + bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; } + + DType dtype() const { + if (has_data()) return data.dtype; + if (has_columnwise_data()) return columnwise_data.dtype; + // Fallback, used e.g. in workspace + return data.dtype; + } + + /*! Matrix height after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_first_dim() const { + if (!has_data() && has_columnwise_data()) { + const auto &data_shape = columnwise_data.shape; + if (data_shape.empty()) return 1; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + return product(data_shape, 1, data_shape.size()); + } else { + return product(data_shape, 0, data_shape.size() - 1); + } + } + const auto &data_shape = data.shape; + if (data_shape.empty()) return 1; + return product(data_shape, 0, data_shape.size() - 1); + } + + /*! 
Matrix width after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_last_dim() const { + if (!has_data() && has_columnwise_data()) { + const auto &data_shape = columnwise_data.shape; + if (data_shape.empty()) return 1; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + return data_shape.front(); + } else { + return data_shape.back(); + } + } + const auto &data_shape = data.shape; + if (data_shape.empty()) return 1; + return data_shape.back(); + } }; template @@ -62,6 +176,10 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +#if CUDA_VERSION >= 12080 +using fp8e8m0 = __nv_fp8_e8m0; +#endif +using e8m0_t = uint8_t; namespace detail { @@ -80,6 +198,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half) TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) +#if CUDA_VERSION >= 12080 +TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) +#endif #undef TRANSFORMER_ENGINE_TYPE_NAME } // namespace detail @@ -150,6 +271,10 @@ struct TypeInfo { using type = fp8e5m2; \ { __VA_ARGS__ } \ } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -181,6 +306,25 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) 
\ switch (dtype) { \ using namespace transformer_engine; \ @@ -236,15 +380,22 @@ struct TypeInfo { NVTE_ERROR("Invalid type for 16 bit."); \ } -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline size_t product(const std::vector &shape) { - size_t ret = 1; - for (const auto &elem : shape) { - ret *= elem; +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Invalid size of the MX scaling factor."); \ + } \ } - return ret; -} + +//////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { int log2_value = 0; @@ -269,13 +420,37 @@ struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; +// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + size_t typeToSize(const DType type); +void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); bool is_fp8_dtype(const DType t); +std::string to_string(const DType type); +std::string to_string(const NVTEScalingMode &type); + +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_block_scaling(const NVTEScalingMode &mode) { + return mode != NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { + 
return is_tensor_scaling(mode); +} + +inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + /*! \brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated @@ -286,6 +461,20 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); +void checkCuDriverContext(CUstream stream); + +CUtensorMapDataType get_CUtensorMapDataType(DType dtype); + +inline bool isPointerAligned(const void *const ptr, const int alignment); + +// Set up parameters to create TMA descriptor. +void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size); + +bool is_supported_by_CC_100(); + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5d3e1d6097..01151a50db 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -93,17 +93,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && - (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && - (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && - (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && - (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || - ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && - 
(max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && - ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || - (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) && + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 8.9: t3hd, max_s=512, d=64, padding + ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; @@ -135,8 +149,12 @@ 
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset) { flag_m512 = true; } - // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - if ( // architecture + if ( + // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging + // special conditions for blackwell + // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 + !(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) && + // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && // sequence length @@ -218,9 +236,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + // TODO(cyang): fix bug for BRCM + cross-attention on sm100 + (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700)))) || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700))))) && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && // check 64-bit ragged offset support diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 20467af663..36ff5291a8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -227,7 +227,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_attn_scale(attn_scale); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_options.set_sliding_window_length(window_size_left + 1); + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } sdpa_options.set_alibi_mask(is_alibi); @@ -457,8 +457,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if (is_ragged && cudnn_runtime_version >= 90600) { @@ -667,7 +665,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_backward_options.set_sliding_window_length(window_size_left + 1); + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } if (cudnn_runtime_version >= 90000) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 0044a94b2f..b4424d9bf6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1670,8 +1670,6 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, "FP8 fused attention does not support 
padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); try { FADescriptor_v1 descriptor{b, @@ -1798,36 +1796,33 @@ void fused_attn_fp8_fwd_impl_v1( // sdpa_options.set_bias(bias); // } - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = 
mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); @@ -1919,29 +1914,28 @@ void fused_attn_fp8_fwd_impl_v1( {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; - // if (is_bias) { - // variant_pack[bias] = devPtrBias; - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } + /* if (is_bias) { + variant_pack[bias] = devPtrBias; + } */ + + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_dropout) { + variant_pack[dropout_seed] = 
devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -1974,8 +1968,6 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); try { FADescriptor_v1 descriptor{b, @@ -2151,36 +2143,35 @@ void fused_attn_fp8_bwd_impl_v1( // } // } - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_backward_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_backward_options.set_padding_mask(is_padding) + .set_seq_len_q(seq_q) + .set_seq_len_kv(seq_kv); + } - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = 
mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_backward_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, @@ -2308,34 +2299,32 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; - // if (is_bias) { - // variant_pack[bias] = devPtrBias; - // if ((bias_b == 1) && (bias_h == h)) { - // variant_pack[dBias] = devPtrdBias; - // } else { - // variant_pack[dBias] = nullptr; - // } - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } + /* if (is_bias) { + 
variant_pack[bias] = devPtrBias; + if ((bias_b == 1) && (bias_h == h)) { + variant_pack[dBias] = devPtrdBias; + } else { + variant_pack[dBias] = nullptr; + } + } */ + + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ef7cdc0af9..52fa89b914 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -15,6 +15,7 @@ #include "../common.h" #include "../util/logging.h" +#include "common/util/cuda_runtime.h" namespace { @@ -46,6 +47,95 @@ uint32_t _getAlignment(uintptr_t address) { } } +struct GemmParam { + void *A; + void *B; + cublasOperation_t transA; + cublasOperation_t transB; + transformer_engine::DType Atype; + transformer_engine::DType Btype; + void *A_scale_inv; + void *B_scale_inv; + int lda; + int ldb; + + GemmParam(cublasOperation_t transA, cublasOperation_t transB) + : A(nullptr), + B(nullptr), + transA(transA), + transB(transB), + Atype(transformer_engine::DType::kNumTypes), + Btype(transformer_engine::DType::kNumTypes), + A_scale_inv(nullptr), + B_scale_inv(nullptr), + lda(0), + ldb(0) {} +}; + +GemmParam 
CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, + const transformer_engine::Tensor &B, const cublasOperation_t transB, + const int k, const int lda, const int ldb) { + using namespace transformer_engine; + NVTE_CHECK(A.scaling_mode == B.scaling_mode, + "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); + NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); + GemmParam ret(transA, transB); + + ret.lda = lda; + ret.ldb = ldb; + + if (is_tensor_scaling(A.scaling_mode)) { + ret.A = A.data.dptr; + ret.A_scale_inv = A.scale_inv.dptr; + if (transA == CUBLAS_OP_T) { + ret.Atype = A.data.dtype; + } else { + ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; + if (is_fp8_dtype(ret.Atype)) { + int arch = cuda::sm_arch(cuda::current_device()); + if (arch < 100) { + // Hopper and Ada - we need to use columnwise_data and change transA + NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } + } + } + ret.B = B.data.dptr; + ret.B_scale_inv = B.scale_inv.dptr; + if (transB == CUBLAS_OP_T) { + ret.Btype = B.has_columnwise_data() ? 
B.columnwise_data.dtype : B.data.dtype; + if (is_fp8_dtype(ret.Btype)) { + int arch = cuda::sm_arch(cuda::current_device()); + if (arch < 100) { + // Hopper and Ada - we need to use columnwise_data and change transB + NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } + } + } else { + ret.Btype = B.data.dtype; + } + } else { + // If not tensor scaling (which includes also high precision types), we need to + // use the proper version of data + // We leave the transA/B values as is, since Blackwell supports transposes + ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; + ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; + ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transB ?
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + } + return ret; +} + } // namespace namespace transformer_engine { @@ -56,10 +146,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { - void *A = inputA->data.dptr; - void *A_scale_inverse = inputA->scale_inv.dptr; - void *B = inputB->data.dptr; - void *B_scale_inverse = inputB->scale_inv.dptr; + // Return immediately if GEMM is trivial + if (m <= 0 || n <= 0) { + return; + } + NVTE_CHECK(k > 0); + + const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -72,15 +165,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, counter = inputCounter->data.dptr; } const bool gelu = pre_gelu_out != nullptr; - const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); - const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype); - const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype); + const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype); + + const cudaDataType_t A_type = get_cuda_dtype(param.Atype); + const cudaDataType_t B_type = get_cuda_dtype(param.Btype); const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype); const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype); - NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr, + NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); - NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr, + NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM
requires inverse of scale!"); // check consistency of arguments: @@ -117,17 +211,17 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k, - transa == CUBLAS_OP_N ? k : m, lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n, - transb == CUBLAS_OP_N ? n : k, ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &transa, sizeof(transa))); + &param.transA, sizeof(param.transA))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &transb, sizeof(transb))); + &param.transB, sizeof(param.transB))); // Set math SM count if (math_sm_count != 0) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -143,12 +237,53 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int8_t fastAccuMode = (use_split_accumulator) ?
0 : 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, - &A_scale_inverse, sizeof(A_scale_inverse))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - &B_scale_inverse, sizeof(B_scale_inverse))); + + // Scaling factors. +#if CUDA_VERSION >= 12080 + cublasLtMatmulMatrixScale_t scaling_mode; +#endif + if ((is_delayed_tensor_scaling(inputA->scaling_mode) && + is_delayed_tensor_scaling(inputB->scaling_mode))) { + void *A_scale_inverse = param.A_scale_inv; + void *B_scale_inverse = param.B_scale_inv; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); +#if CUDA_VERSION >= 12080 + scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. + // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
+ if (cublasLtGetVersion() <= 120803) { + const int64_t dummy_a_vec_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, + sizeof(dummy_a_vec_stride))); + } +#endif + } else { + NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + + to_string(inputB->scaling_mode) + "."); + } + +#if CUDA_VERSION >= 12080 + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); +#endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output C = nullptr; @@ -156,8 +291,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); - // For FP8 output, cuBLAS requires C_type to be same as bias_type - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd)); +#if CUDA_VERSION >= 12080 + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); +#endif + // For FP8 output, cuBLAS requires C_type to match bias_type and + // be FP16/BF16 + const cudaDataType_t C_type = bias ? 
bias_type : CUDA_R_16BF; + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd)); } else { NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd)); } @@ -235,8 +376,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); - const auto A_alignment = _getAlignment(reinterpret_cast(A)); - const auto B_alignment = _getAlignment(reinterpret_cast(B)); + const auto A_alignment = _getAlignment(reinterpret_cast(param.A)); + const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -260,8 +401,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, static_cast(&one), /* alpha */ - A, /* A */ - Adesc, B, /* B */ + param.A, /* A */ + Adesc, param.B, /* B */ Bdesc, static_cast(&beta), /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -270,7 +411,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor - if (is_fp8_dtype(outputD->data.dtype)) { + // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. + // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be + // calculated here. 
+ if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) { update_tensor_scale_inv(outputD, stream); } @@ -309,9 +453,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; + const size_t A0 = inputA->flat_first_dim(); + const size_t A1 = inputA->flat_last_dim(); + const size_t B0 = inputB->flat_first_dim(); + const size_t B1 = inputB->flat_last_dim(); + + const int m = transa ? A0 : A1; + const int k = transa ? A1 : A0; + const int n = transb ? B1 : B0; int lda, ldb, ldd; if (transa && !transb) { // TN lda = k; @@ -357,6 +506,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = reinterpret_cast(counter); Tensor *wspace = reinterpret_cast(workspace); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && + is_delayed_tensor_scaling(inputB->scaling_mode), + "Atomic GEMM only supports delayed scaling."); + const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 53a66c25b5..49029ed588 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -19,7 +19,9 @@ extern "C" { /* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ -/*! \brief Compute activation of the input. +/*! \brief Computes activation of the input. 
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor for activation. * \param[in,out] output Output tensor. @@ -39,17 +41,59 @@ enum class NVTE_Activation_Type { SREGLU, }; +/*! \brief Computes the GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. 
+ * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute activation gradient. +/*! \brief Computes the GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] grad Incoming gradient. * \param[in] input Input tensor for activation. @@ -59,19 +103,57 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. 
+ * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute gated activation of the input. +/*! \brief Computes the gated GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor of shape [N, H * 2]. * \param[in,out] output Output tensor of shape [N, H]. @@ -80,15 +162,54 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu */ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input. 
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. 
+ * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute gated activation gradient. +/*! \brief Computes the gated GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * * \param[in] grad Incoming gradient of shape [N, H]. * \param[in] input Forward input tensor of shape [N, H * 2]. * \param[in,out] output Outgoing gradient of shape [N, H * 2]. @@ -97,15 +218,51 @@ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. 
+ */ void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 88a7dec251..d57975b2f4 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -5,7 +5,7 @@ ************************************************************************/ /*! \file cast.h - * \brief Functions to cast to/from FP8. + * \brief Functions to cast to/from FP8/MXFP8. */ #ifndef TRANSFORMER_ENGINE_CAST_H_ @@ -17,21 +17,200 @@ extern "C" { #endif -/*! \brief Cast tensor to FP8. 
+/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) + * The implementation is per the microscaling format MXFP8 defined by the OCP specification: + * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf * - * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8 tensor. - * \param[in] stream CUDA stream used for the operation. + * Supported modes of scaling (live scaling): + * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: + * - the scaled output tensor + * - the corresponding scaling factors + * The scaling factors are computed for blocks of the shape [1,32] + * (i.e., each scaling factor spans 32 contiguous elements along rows). + * + * 2) Columnwise scaling (along the dim=1) computes one set of the output data. + * The scaling factors are computed for blocks of the shape [32,1] + * (i.e., each scaling factor spans 32 contiguous elements along columns). + * + * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) + * computes two sets of the output data: both 1) and 2). + * + * The shape of the MX block must be specified in the 'output' argument, + * and can be either [1,32] or [32,1] as no other shapes are currently supported. + * + * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter + * of the output tensor should be set to 0. + */ + +/*! \brief Casts input tensor to FP8/MXFP8. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*!
\brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] noop Noop tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream); + +/*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(input)` + * - `dbias` is equal to `reduce(input, dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workplace, cudaStream_t stream); + +/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used.
+ * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. 
+ * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); -/*! \brief Cast tensor from FP8. +/*! \brief Casts input tensor from reduced to higher precision. + * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, + * the block dequantization (MXFP8) of the specified shape of the block will be used. + * In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise + * data of the output tensor, regardless of whether the row- or columnwise scaling is used. * - * \param[in] input Input tensor to be cast. - * \param[out] output Output tensor. + * \param[in] input Input FP8/MXFP8 tensor to be cast. + * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index ea3bdcd14e..678ffe9191 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -17,11 +17,26 @@ extern "C" { #endif +/*! \brief Transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. 
+ */ void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); -void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, - NVTETensor cast_output, NVTETensor transposed_output, +/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 8e0d017a0d..293c57526d 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -53,6 +53,8 @@ class CommOverlapCore { int _cga_size; int _use_ce; int _ub_reg; + int _gemm_priority; + int _comm_priority; bool _atomic_gemm{false}; bool _is_p2p{false}; @@ -65,10 +67,13 @@ class CommOverlapCore { cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: + CommOverlapCore() {} // dummy constructor for exposing type to Python + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool use_ce, bool atomic_gemm); + int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm); virtual ~CommOverlapCore(); @@ -77,25 +82,76 @@ class 
CommOverlapCore { _ubuf_scale_inv_initialized = true; } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, + const std::vector &shape); + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_ag(const TensorWrapper &A, 
bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { protected: int _rs_kernel_type; + bool _rs_overlap_first_gemm; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; public: + CommOverlapBase() {} // dummy constructor for exposing type to Python + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); virtual ~CommOverlapBase(); @@ -103,97 +159,124 @@ class CommOverlapBase : public CommOverlapCore { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + 
void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main); + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; }; // CommOverlapBase class 
CommOverlapP2PBase : public CommOverlapCore { protected: bool _is_reduce_scatter{false}; bool _use_multiatomic_ag{false}; - + bool _aggregate; int _next_rank; int _prev_rank; int _rank_round_tp; - int _aggregate; int _num_ubuf_chunks; int _self_chunk_id; - std::vector _ubufs; - - cudaStream_t _stream_send; + std::vector _stream_send; cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; public: + CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, - int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, - bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); virtual ~CommOverlapP2PBase(); + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); + + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
*/ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; /* ** Split AllGather + GEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ - void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split 
ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + cudaStream_t stream_main) override; }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index a076a4e89a..b30a6e1338 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -28,16 +28,10 @@ extern "C" { * \param[in] amax_history History of maximum absolute values. * Shape: [history_length, num_scales] * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] - * \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales] - * \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be - * empty, in which case all scale_inv entries are updated. - * Shape: [num_scales] * \param[out] updated_amax_history Updated history of maximum absolute values. * Shape: [history_length, num_scales] * \param[out] updated_scale Updated scaling factor for casting to FP8. * Shape: [num_scales] - * \param[out] updated_scale_inv Updated scaling factor for casting from FP8. - * Shape: [num_scales] * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * "most_recent". * \param[in] fp8_dtype FP8 datatype. @@ -45,9 +39,8 @@ extern "C" { * \param[in] stream CUDA stream. 
*/ void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, - NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. @@ -55,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( * Operations performed include, updating the most recent amax history * with the relevant segment of global reduction buffer if it's not 0, * rotating the amax history based on the rule below, and updating the - * scales and scale_invs. + * scales. * * The amax history is rotated by -1 (e.g. the first entry shifts to * the last, the last entry shifts to the second to last) and the @@ -69,8 +62,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( * Shape: num_tensors x [history_length, num_scales] * \param[in,out] scales List of scaling factors for casting to FP8. * Shape: num_tensors x [num_scales] - * \param[in,out] scale_invs List of scaling factors for casting from FP8. - * Shape: num_tensors x [num_scales] * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * "most_recent". * \param[in] fp8_dtype FP8 datatype. 
@@ -79,8 +70,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( */ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h new file mode 100644 index 0000000000..de5a11eb73 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file swizzle.h + * \brief Functions to swizzle scaling factors into the layout required by GEMM. + */ + +#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_ +#define TRANSFORMER_ENGINE_SWIZZLE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] input Input tensor with non-swizzled scale_inv. + * \param[in,out] output Output tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
+ */ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_SWIZZLE_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 99b3508362..e393dbffc4 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -30,6 +30,7 @@ enum NVTEDType { kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ + kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */ kNVTENumTypes /*!< Number of supported types */ }; @@ -43,6 +44,42 @@ struct NVTEShape { size_t ndim; }; +/*! \struct NVTEBasicTensor + * \brief A basic tensor type used to populate parameters of NVTETensor. + * It does not own the memory it points to. + */ +struct NVTEBasicTensor { + void *data_ptr; + NVTEDType dtype; + NVTEShape shape; +}; + +/*! \enum NVTETensorParam + * \brief Indicates the kind of the tensor parameter to set/get. + */ +enum NVTETensorParam { + kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */ + kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */ + kNVTEScale = 2, /*!< Scale tensor */ + kNVTEAmax = 3, /*!< Amax tensor */ + kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ + kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTENumTensorParams +}; + +/*! \enum NVTEScalingMode + * \brief Granularity of scaling: + */ +enum NVTEScalingMode { + /*! Single scale per tensor, computed in delayed manner. + Used also for high precision data, without scaling */ + NVTE_DELAYED_TENSOR_SCALING = 0, + /*! 
Single scale per block of 32 elements consecutive in either + rowwise or columnwise direction */ + NVTE_MXFP8_1D_SCALING = 1, + NVTE_INVALID_SCALING +}; + /*! \brief TE Tensor type * * NVTETensor is a contiguous tensor type storing a pointer @@ -53,21 +90,15 @@ typedef void *NVTETensor; /*! \brief Create a new TE tensor. * - * Create a new TE tensor with a given shape, datatype and data. + * Create a new TE tensor. Before use its parameters need to be set. * TE tensors are just wrappers on top of raw data and do not * own memory. * - * \param[in] dptr Pointer to the tensor data. - * \param[in] shape Shape of the tensor. - * \param[in] dtype Data type of the tensor. - * \param[in] amax_dptr Pointer to the AMAX value. - * \param[in] scale_dptr Pointer to the scale value. - * \param[in] scale_inv_dptr Pointer to the inverse of scale value. + * \param[in] scaling_mode Scaling mode of the tensor. * * \return A new TE tensor. */ -NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, - float *amax_dptr, float *scale_dptr, float *scale_inv_dptr); +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); /*! \brief Destroy a TE tensor. * @@ -78,14 +109,22 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType */ void nvte_destroy_tensor(NVTETensor tensor); -/*! \brief Get a raw pointer to the tensor's data. +/*! \brief Get a raw pointer to the tensor's rowwise data. * * \param[in] tensor Tensor. * - * \return A raw pointer to tensor's data. + * \return A raw pointer to tensor's rowwise data. */ void *nvte_tensor_data(const NVTETensor tensor); +/*! \brief Get a raw pointer to the tensor's columnwise data. + * + * \param[in] tensor Tensor. + * + * \return A raw pointer to tensor's columnwise data. + */ +void *nvte_tensor_columnwise_data(const NVTETensor tensor); + /*! \brief Get a tensor's data shape. * * \param[in] tensor Tensor. 
@@ -94,6 +133,14 @@ void *nvte_tensor_data(const NVTETensor tensor); */ NVTEShape nvte_tensor_shape(const NVTETensor tensor); +/*! \brief Get a tensor's columnwise data shape. + * + * \param[in] tensor Tensor. + * + * \return A shape of the columnwise data of the input tensor. + */ +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor); + /*! \brief Get a tensor's number of dimensions. * * \param[in] tensor Tensor. @@ -159,6 +206,46 @@ float *nvte_tensor_scale(const NVTETensor tensor); */ float *nvte_tensor_scale_inv(const NVTETensor tensor); +/*! \brief Get a tensor's scale_inv shape. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor); + +/*! \brief Reset tensor value to zero. + * + * \param[in] tensor Tensor. + * + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); + +/*! \brief Set a parameter of the tensor. + * + * \param[in,out] tensor Tensor. + * \param[in] param_name The parameter to be set. + * \param[in] param The value to be set. + */ +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param); + +/*! \brief Get a value of the parameter of the tensor. + * + * \param[in] tensor Tensor. + * \param[in] param_name The parameter to be retrieved. + */ +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name); + +/*! \brief Get the granularity of scaling of this tensor. + * + * \param[in] tensor Tensor. + * + * \return The scaling mode (granularity of scaling) of the tensor. + */ +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor); + /*! \struct NVTETensorPack \brief Pack of tensors, generally used for auxiliary outputs.
*/ @@ -201,6 +288,7 @@ enum class DType { kBFloat16 = 5, kFloat8E4M3 = 6, kFloat8E5M2 = 7, + kFloat8E8M0 = 8, kNumTypes }; @@ -220,12 +308,23 @@ class TensorWrapper { * \param[in] dtype Data type of the tensor. * \param[in] amax_dptr Pointer to the AMAX value. * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv * \param[in] scale_inv_dptr Pointer to the inverse of scale value. */ TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, - float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr) - : tensor_(nvte_create_tensor(dptr, shape, static_cast(dtype), amax_dptr, - scale_dptr, scale_inv_dptr)) {} + float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, + const NVTEShape scale_inv_shape = defaultShape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { + tensor_ = nvte_create_tensor(scaling_mode); + NVTEBasicTensor data = {dptr, static_cast(dtype), shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data); + NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax); + NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEScale, &scale); + NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv); + } /*! \brief Constructs new TensorWrapper. * @@ -238,19 +337,23 @@ class TensorWrapper { * \param[in] dtype Data type of the tensor. * \param[in] amax_dptr Pointer to the AMAX value. * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv * \param[in] scale_inv_dptr Pointer to the inverse of scale value. 
*/ TensorWrapper(void *dptr, const std::vector &shape, const DType dtype, float *amax_dptr = nullptr, float *scale_dptr = nullptr, - float *scale_inv_dptr = nullptr) + float *scale_inv_dptr = nullptr, const std::vector &scale_inv_shape = {1}, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, - scale_inv_dptr) {} + scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, + scaling_mode) {} /*! \brief Constructs new empty TensorWrapper. * * Create a new empty TE tensor which holds nothing. */ - TensorWrapper() : TensorWrapper(nullptr, std::vector(), DType::kFloat32) {} + explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_tensor(scaling_mode)) {} /*! \brief TensorWrapper destructor. */ ~TensorWrapper() { nvte_destroy_tensor(tensor_); } @@ -283,6 +386,70 @@ class TensorWrapper { return *this; } + // Parameter setters + template + TensorWrapper &set_parameter(const NVTETensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_tensor_param(&tensor_, param, &data); + return *this; + } + + template + TensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEScale, dptr, type, shape); + } + + template + TensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEAmax, dptr, type, shape); + } + + 
template + TensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseScaleInv, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); + } + + // Parameter getters + + NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { + return nvte_get_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTERowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEColumnwiseScaleInv); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -298,6 +465,15 @@ class TensorWrapper { return nvte_tensor_shape(tensor_); } + /*! \brief Get the columnwise shape of this TensorWrapper. + * + * \return Columnwise shape of this TensorWrapper. + */ + const NVTEShape columnwise_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_columnwise_shape(tensor_); + } + /*! \brief Get the size of this TensorWrapper in the given dimension. * * \param[in] size_t Dimension index. @@ -366,6 +542,15 @@ class TensorWrapper { return nvte_tensor_data(tensor_); } + /*! \brief Get a raw pointer to the tensor's columnwise data. + * + * \return A raw pointer to tensor's columnwise data. + */ + void *columnwise_dptr() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_columnwise_data(tensor_); + } + /*!
\brief Get a pointer to the tensor's amax data. * * \return A pointer to tensor's amax data. @@ -393,7 +578,34 @@ class TensorWrapper { return nvte_tensor_scale_inv(tensor_); } + /*! \brief Get the scale_inv_shape of this TensorWrapper. + * + * \return scale_inv_shape of this TensorWrapper. + */ + const NVTEShape scale_inv_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_scale_inv_shape(tensor_); + } + + /*! \brief Get a scaling mode of the tensor. + * + * \return Scaling mode of the tensor. + */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_tensor_scaling_mode(tensor_); + } + + void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = {&defaultData, 1}; + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { return {s.data(), s.size()}; } + /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; }; diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 781f171cd8..a7db5cba47 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -20,16 +20,16 @@ extern "C" { /*! \brief Cast and transpose the input. * * This function casts the input and produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * - rowwise data in `output` is the result of the cast + * - columnwise data in `output` is the transposed result of the cast. * - * \param[in] input Input tensor of shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. 
- * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, cudaStream_t stream); +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Transpose the input. * @@ -41,25 +41,24 @@ void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaSt /*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension. * - * This function casts the input and produces 3 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * This function casts the input and produces 2 results: + * - `output` is the result of the cast (rowwise data) and transposed cast (columnwise data) * - `dbias` is the result of the reduction of the input along the first dimension. * * Calling this function with workspace being an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * - * \param[in] input Input tensor of shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the input along the - * first dimension. Shape: [H]. - * \param[out] workspace Workspace tensor. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. 
+ * Shape of the columnwise data: [H, N] + * \param[out] dbias Result of the reduction of the input along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream); +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream); /*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension. * @@ -82,102 +81,242 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp /*! \brief Cast and transpose multiple tensors. * - * This function casts each input tensor and produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. - * - * \param[in] num_tensors Number of tensors. - * \param[in] input_list List of 2D input tensors. - * \param[in,out] cast_output_list List of casted tensors. Dimensions - * match tensors in input_list. - * \param[in,out] transposed_output_list List of casted and transposed - * tensors. Dimensions are transpose - * of tensors in input_list. - * \param[in] stream CUDA stream used for the operation. + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of casted tensors. Dimensions + * of their rowwise data members match + * tensors in input_list. Dimensions of + * their columnwise data members are + * transposed. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, - NVTETensor* cast_output_list, NVTETensor* transposed_output_list, - cudaStream_t stream); + NVTETensor* output_list, cudaStream_t stream); -/*! 
\brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally, - * reduce the result of the SiLU backward along the first dimension. +/*! \brief Compute backward of GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the GeLU backward along the first dimension. * - * This function produces 3 results: - * - `cast_output` is equal to `cast(dact(input))` - * - `transposed_output` is equal to `transpose(cast(dact(input)))` + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` * - `dbias` is equal to `reduce(dact(input), axis=0)` * * Calling this function with workspace being an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] input Input tensor of shape [N, H]. - * \param[in] act_input Tensor used as input to the forward of SiLU operation. + * \param[in] act_input Tensor used as input for the operation of forward activation. * Shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the dSiLU(input) along the + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the * first dimension. Shape: [H]. * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. 
- - Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ - void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the SiLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the ReLU backward along the first dimension. 
+ * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of the Quick GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Quick GeLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. 
+ * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of the Squared ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Squared ReLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. 
+ */ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); -/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output. +/*! \brief Computes the backward of the gated GeLU activation on the input, additionally casts + * and transposes the output. * * This function produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` * * \param[in] input Input tensor of shape [N, H]. - * \param[in] gated_act_input Tensor used as input to the forward of GeGLU operation. + * \param[in] act_input Tensor used as input to the forward of + * gated activation operation. * Shape [N, H * 2]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H * 2]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. * \param[in] stream CUDA stream used for the operation. - - Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ - void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the backward of the gated Swish activation on the input, + * additionally casts and transposes the output.
+ * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the backward of the gated ReLU activation on the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the backward of the gated Quick GeLU activation on the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H].
+ * \param[in] act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the backward of the gated Squared ReLU activation on the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation.
+*/ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 89e2e9feec..7ef3ac44e7 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -15,6 +15,7 @@ #include #include "transformer_engine/normalization.h" +#include "transformer_engine/transformer_engine.h" /* @@ -38,13 +39,21 @@ Compute always in FP32 namespace transformer_engine { namespace normalization { -TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, - DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, - bool zero_centered_gamma, bool is_tuned) { +cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { + return training ? 
cudnn_frontend::NormFwdPhase_t::TRAINING + : cudnn_frontend::NormFwdPhase_t::INFERENCE; +} + +TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, + NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, + uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, + bool is_tuned, NVTEScalingMode mode, bool training) { + // TODO: Add scaling_mode to general_key is needed uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | - (uint32_t(zero_centered_gamma) << 16); + (uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | + (uint32_t(mode) << 19) | (uint32_t(training) << 22); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } @@ -64,8 +73,8 @@ TeNormalizationPlan::TeNormalizationPlan( kernel_params.fp8_out = is_fp8_dtype(otype); } // TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those - auto key = - get_key(NormType, NormStage, wtype, itype, otype, ctype, 0, hidden_size, false, is_tuned); + auto key = get_key(NVTE_Norm_Backend::Te, NormType, NormStage, wtype, itype, otype, ctype, 0, + hidden_size, false, is_tuned); _kernel = KernelRegistry::getKernel(key); this->_build(); @@ -179,13 +188,25 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor DType wtype, DType itype, DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, - const bool zero_centered_gamma) - : _fp8_out(is_fp8_dtype(otype)), _zero_centered(zero_centered_gamma) { + const bool zero_centered_gamma, + const NVTEScalingMode mode, bool training) + : _fp8_out(is_fp8_dtype(otype)), + _zero_centered(zero_centered_gamma), + _training(training), + _norm_stage(NormStage), + _norm_type(NormType) { static_assert(CUDNN_FRONTEND_VERSION >= 10601, "CUDNN_FRONTEND_VERSION should be 
at least 1.6.1!"); namespace fe = cudnn_frontend; + if (is_tensor_scaling(mode)) { + _ndim_scale_block = 0; + } else { + NVTE_CHECK(mode == NVTE_MXFP8_1D_SCALING, "Unsupported scaling mode."); + _ndim_scale_block = 1; + } + _scalar_dptr = std::make_unique(typeToSize(wtype)); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); @@ -213,7 +234,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_dim({1, hidden_dim, 1, 1}) .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) .set_data_type(get_cudnn_fe_dtype(wtype))); - if (zero_centered_gamma) { + if (_zero_centered) { _scalar_offset = _graph.tensor(fe::graph::Tensor_attributes() .set_name("one") .set_dim({1, 1, 1, 1}) @@ -230,59 +251,97 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor } // Create graph computation nodes - if (NormStage == NVTE_Norm_Stage::Forward) { + if (_norm_stage == NVTE_Norm_Stage::Forward) { _eps = _graph.tensor(fe::graph::Tensor_attributes() .set_name("epsilon") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ctype)) .set_is_pass_by_value(true)); - if (NormType == NVTE_Norm_Type::LayerNorm) { + if (_norm_type == NVTE_Norm_Type::LayerNorm) { _beta = _graph.tensor(fe::graph::Tensor_attributes() .set_name("bias") .set_dim({1, hidden_dim, 1, 1}) .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) .set_data_type(get_cudnn_fe_dtype(wtype))); auto norm_options = fe::graph::Layernorm_attributes() - .set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_forward_phase(get_cudnn_forward_phase(_training)) .set_epsilon(_eps) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options); std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]); - _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); - } else if (NormType == NVTE_Norm_Type::RMSNorm) { + if 
(_training) _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else { auto norm_options = fe::graph::Rmsnorm_attributes() - .set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_forward_phase(get_cudnn_forward_phase(_training)) .set_epsilon(_eps) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); auto ret = _graph.rmsnorm(_x, _gamma, norm_options); std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]); } - _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + if (_training) _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); const auto ZDtype = _fp8_out ? ctype : otype; _z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype)); if (_fp8_out) { - // create a scale node - _z_scale = _graph.tensor(fe::graph::Tensor_attributes() - .set_name("z_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ctype))); - auto z_scale_options = fe::graph::Pointwise_attributes() - .set_mode(fe::PointwiseMode_t::MUL) - .set_compute_data_type(get_cudnn_fe_dtype(ctype)); - _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); - - _z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); - - // create an amax reduction node - _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() - .set_mode(fe::ReductionMode_t::AMAX) - .set_compute_data_type(get_cudnn_fe_dtype(ctype))); - _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + if (_ndim_scale_block == 0) { // tensor_scaling + // create a scale node + _z_scale = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("z_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + auto z_scale_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); + + 
_z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + + // create an amax reduction node + _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() + .set_mode(fe::ReductionMode_t::AMAX) + .set_compute_data_type(get_cudnn_fe_dtype(ctype))); + _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + _one_for_div = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("one_for_div") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype)) + .set_is_pass_by_value(true)); + auto div_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::DIV) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_scale_inv = _graph.pointwise(_one_for_div, _z_scale, div_options); + _z_scale_inv->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else if (_ndim_scale_block == 1) { // 1d block scaling + auto z_2d = _graph.reshape(_z, fe::graph::Reshape_attributes()); + z_2d->set_dim({batch_dim, hidden_dim}); + + auto mx_quantize_row_opts = fe::graph::Block_scale_quantize_attributes() + .set_block_size(32) + .set_axis(1) + .set_transpose(false); + auto bs_row_ret = _graph.block_scale_quantize(z_2d, mx_quantize_row_opts); + std::tie(_z_mx_row, _sf_row) = std::make_tuple(bs_row_ret[0], bs_row_ret[1]); + _z_mx_row->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _sf_row->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); //TODO + + if (_training) { + auto mx_quantize_col_opts = fe::graph::Block_scale_quantize_attributes() + .set_block_size(32) + .set_axis(0) + .set_transpose(false); + auto bs_col_ret = _graph.block_scale_quantize(z_2d, mx_quantize_col_opts); + std::tie(_z_mx_col, _sf_col) = std::make_tuple(bs_col_ret[0], bs_col_ret[1]); + _z_mx_col->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _sf_col->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); + } + } else { + NVTE_ERROR("Unsupported scaling mode."); + } } } 
else { _dz = _graph.tensor(fe::graph::Tensor_attributes() @@ -299,7 +358,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_dim({batch_dim, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ctype))); - if (NormType == NVTE_Norm_Type::LayerNorm) { + if (_norm_type == NVTE_Norm_Type::LayerNorm) { auto norm_options = fe::graph::Layernorm_backward_attributes() .set_saved_mean_and_inv_variance(_mean, _rsigma) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); @@ -341,10 +400,14 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, cudaStream_t stream) { // Binding data pointers to graph tensors - _variant_pack = {{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_eps, eps_dptr}}; + _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}}; - // layernorm should have valid mean_dptr and beta_dptr - if (mean_dptr && beta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_beta, beta_dptr}}); + if (_training) _variant_pack.insert({{_rsigma, rsigma_dptr}}); + + if (_norm_type == NVTE_Norm_Type::LayerNorm) { + _variant_pack.insert({{_beta, beta_dptr}}); + if (_training) _variant_pack.insert({{_mean, mean_dptr}}); + } if (_zero_centered) _variant_pack.insert( @@ -352,16 +415,24 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, else _variant_pack.insert({{_gamma, gamma_dptr}}); - if (_fp8_out) - _variant_pack.insert( - {{_z_scale, z->scale.dptr}, {_amax, z->amax.dptr}, {_z_fp8, z->data.dptr}}); - else + if (_fp8_out && _ndim_scale_block == 0) { + _variant_pack.insert({{_one_for_div, reinterpret_cast(_one_dptr.get())}, + {_z_scale, z->scale.dptr}, + {_z_scale_inv, z->scale_inv.dptr}, + {_amax, z->amax.dptr}, + {_z_fp8, z->data.dptr}}); + } else if (_fp8_out && _ndim_scale_block == 1) { + _variant_pack.insert({{_z_mx_row, z->data.dptr}, {_sf_row, z->scale_inv.dptr}}); + if (_training) + _variant_pack.insert( 
+ {{_z_mx_col, z->columnwise_data.dptr}, {_sf_col, z->columnwise_scale_inv.dptr}}); + } else { _variant_pack.insert({{_z, z->data.dptr}}); + } // Execute the computation NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); - if (_fp8_out) update_tensor_scale_inv(z, stream); } void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, @@ -389,11 +460,12 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, - const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned) { + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, + const NVTEScalingMode mode, const bool training) { const DType ctype = DType::kFloat32; bool is_tuned = is_aligned && (batch_size % 4 == 0); - auto key = get_key(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, - zero_centered_gamma, is_tuned); + auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, + hidden_size, zero_centered_gamma, is_tuned, mode, training); auto it = normalizationPlanMap.find(key); if (it != normalizationPlanMap.end()) { @@ -404,7 +476,7 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( if (NormBackend == NVTE_Norm_Backend::Cudnn) { plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, - zero_centered_gamma); + zero_centered_gamma, mode, training); } else if (NormStage == NVTE_Norm_Stage::Forward) { plan = std::make_unique>( NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, diff --git a/transformer_engine/common/normalization/common.h 
b/transformer_engine/common/normalization/common.h index f366ba26db..ea0450f1c2 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -154,9 +154,12 @@ struct TupleHash { } }; -TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, - DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, - bool zero_centered_gamma, bool is_tuned); +// Note: the default mode here should match with the default mode with QTensor +TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, + NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, + uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, + bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, + bool training = true); template class TeNormalizationRegistry { @@ -257,7 +260,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, - const bool zero_centered_gamma); + const bool zero_centered_gamma, const NVTEScalingMode mode, + const bool training); std::vector getWorkspaceShape() const override; @@ -273,10 +277,17 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { void _build() override; const bool _zero_centered, _fp8_out; + int _ndim_scale_block; + const NVTE_Norm_Stage _norm_stage; + const NVTE_Norm_Type _norm_type; std::unique_ptr _scalar_dptr; + std::unique_ptr _one_dptr = std::make_unique(1.0f); // FWD std::shared_ptr _x, _gamma_zero, _scalar_offset, _gamma, _beta, - _eps, _mean, _rsigma, _z, _z_scale, _amax, _z_fp8; + _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8; + // MX FWD + std::shared_ptr _z_mx_row, _z_mx_col, _sf_row, _sf_col; + const bool _training; // BWD 
std::shared_ptr _dz, _dx, _dgamma, _dbeta; @@ -292,12 +303,11 @@ class NormalizationPlanRegistry { return instance; } - NormalizationPlanBase* getNormalizationPlan(NVTE_Norm_Backend NormBackend, - NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, - DType wtype, DType itype, DType otype, - const size_t batch_size, const size_t hidden_size, - const size_t sm_count, const bool zero_centered_gamma, - const bool is_aligned); + NormalizationPlanBase* getNormalizationPlan( + NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, + const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); private: NormalizationPlanRegistry() {} @@ -356,15 +366,12 @@ struct TypeToDType { static int \ register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \ TeNormalizationRegistry::registerFunction( \ - (get_key(NVTE_Norm_Type::NORM_TYPE, NVTE_Norm_Stage::NORM_STAGE, \ - (TypeToDType::value), (TypeToDType::value), \ - (TypeToDType::value), (TypeToDType::value), 0, HIDDEN_SIZE, \ - 0, IS_TUNED(LAUNCH_TYPE))), \ + (get_key(NVTE_Norm_Backend::Te, NVTE_Norm_Type::NORM_TYPE, \ + NVTE_Norm_Stage::NORM_STAGE, (TypeToDType::value), \ + (TypeToDType::value), (TypeToDType::value), \ + (TypeToDType::value), 0, HIDDEN_SIZE, 0, IS_TUNED(LAUNCH_TYPE))), \ FUNC_NAME) -// For FP8 only -void ComputeScaleInv(void* scale, void* scale_inv); - // Alignment check template bool is_ptr_aligned(const Args*... 
ptrs) { @@ -375,7 +382,6 @@ bool use_cudnn_norm_fwd(); bool use_cudnn_norm_bwd(); } // namespace normalization - } // namespace transformer_engine #endif diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index a412bae745..dae39d82bf 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include #include @@ -25,6 +26,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { + if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && + !is_block_scaling(z->scaling_mode)) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); + } + NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(gamma.data.shape == beta.data.shape); NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); @@ -51,7 +57,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - if (use_cudnn_norm_fwd()) { + bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + + if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; } else { @@ -59,6 +67,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, rsigma->data.dptr); } + + bool training = + is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward, gamma.data.dtype, // 
wtype @@ -66,18 +78,31 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); workspace->data.dtype = DType::kByte; return; - } else { - NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); - plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, - reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, - workspace->data.dptr, stream); } + + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + NVTE_CHECK( + !is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr), + "Columnwise scale_inv must be allocated for NormFwdTraining!"); + plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + + // Compute FP8 transpose if required + if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { + Tensor transpose_data; + transpose_data.data = z->columnwise_data; + transpose_data.scaling_mode = z->scaling_mode; + nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), + stream); + } + return; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index dd4c8e580d..8519fe1b64 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -13,6 +13,7 @@ #include "../../common.h" #include "../common.h" #include "transformer_engine/normalization.h" +#include "transformer_engine/transpose.h" namespace transformer_engine { @@ -21,6 +22,11 @@ using namespace 
normalization; void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { + if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && + !is_block_scaling(z->scaling_mode)) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); + } + NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); @@ -39,17 +45,21 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens CheckOutputTensor(*rsigma, "rsigma"); } - Tensor empty; - NVTE_Norm_Backend norm_backend; bool is_aligned = true; - if (use_cudnn_norm_fwd()) { + bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + + bool training = + is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + + if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; } else { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, gamma.data.dtype, // wtype @@ -57,17 +67,29 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); workspace->data.dtype = DType::kByte; return; - } else { - NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); - plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr, nullptr, - 
reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, - workspace->data.dptr, stream); + } + + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + NVTE_CHECK( + !is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr), + "Columnwise scale_inv must be allocated for NormFwdTraining!"); + plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr /*beta*/, nullptr /*mu*/, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + + // Compute FP8 transpose if required + if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { + Tensor transpose_data; + transpose_data.data = z->columnwise_data; + transpose_data.scaling_mode = z->scaling_mode; + nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), + stream); } return; @@ -101,8 +123,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const CheckOutputTensor(*dgamma, "dgamma"); } - Tensor empty; - NVTE_Norm_Backend norm_backend; bool is_aligned = true; if (use_cudnn_norm_bwd()) { @@ -128,8 +148,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const return; } else { NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); - plan->execute(x.data.dptr, gamma.data.dptr, nullptr, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, nullptr, dgamma->data.dptr, workspace->data.dptr, stream); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream); } return; } diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 2c9944439d..f68edf155c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -39,19 +39,22 @@ class Format(Enum): HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) -class 
_OverrideLinearPrecision(NamedTuple): +class Recipe: """ - Whether or not the execute the `fprop`, `dgrad`, and `wgrad` - GEMMs in higher precision when using FP8. + Base recipe class. """ - fprop: bool = False - dgrad: bool = False - wgrad: bool = False + def mxfp8(self): + """Whether the given recipe is MXFP8 block scaling.""" + return isinstance(self, MXFP8BlockScaling) + + def delayed(self): + """Whether the given recipe is delayed scaling.""" + return isinstance(self, DelayedScaling) @dataclass() -class DelayedScaling: +class DelayedScaling(Recipe): """ Use the delayed scaling factor strategy. Use scale factor from previous iteration and record amax history of `amax_history_len` steps. @@ -92,9 +95,6 @@ def scaling_factor_compute(amax: Tensor, recipe: DelayedScaling) -> Tensor where `Tensor` is a framework tensor type. - override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False) - Whether or not to execute the `fprop`, `dgrad`, and `wgrad` - GEMMs (respectively) in higher precision when using FP8. reduce_amax: bool, default = `True` By default, if `torch.distributed` is initialized, the `amax` value for FP8 tensors is reduced across the `fp8_group` (specified in the `fp8_autocast` @@ -137,7 +137,6 @@ def scaling_factor_compute(amax: Tensor, fp8_format: Format = Format.HYBRID amax_history_len: int = 1024 amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max" - override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision() scaling_factor_compute_algo: Optional[Callable] = None reduce_amax: bool = True fp8_dpa: bool = False @@ -145,10 +144,6 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert self.override_linear_precision in ( - (False, False, False), - (False, False, True), - ), "Only wgrad GEMM override is currently supported." 
if self.interval >= 0: warnings.warn( "`interval` argument is deprecated and unused. " @@ -161,7 +156,32 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " - f"wgrad_override={self.override_linear_precision.wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class MXFP8BlockScaling(Recipe): + """ + Use the current scaling factor strategy. + + Parameters + ---------- + margin : int, default = 0 + Margin for the scaling factor computation. + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + """ + + margin: int = 0 + fp8_format: Format = Format.E4M3 + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + + def __repr__(self) -> str: + return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index b16bad9e6a..658ce054da 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -46,7 +46,6 @@ struct AmaxParam { int num_scale = 0; float* amax_history = nullptr; float* scale = nullptr; - float* scale_inv = nullptr; }; // dummy struct for kernel_bulk's other params @@ -83,10 +82,9 @@ constexpr size_t bsize = 256; * Grid dims: num_scales x 1 x 1 */ __global__ void __launch_bounds__(bsize) - kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr, - const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr, - float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length, - size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) { + kernel(const float* 
amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr, + float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride, + AmaxComputeAlgo amax_compute_algo, float scaled_max) { const size_t tid = threadIdx.x; const size_t bid = blockIdx.x; @@ -135,7 +133,7 @@ __global__ void __launch_bounds__(bsize) } } - // Update scale and scale inverse + // Update scale if (tid == 0) { // Update scale float scale; @@ -152,15 +150,6 @@ __global__ void __launch_bounds__(bsize) scale = std::numeric_limits::max(); } updated_scale_ptr[bid] = scale; - - // Update scale inverse - float scale_inv; - if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { - scale_inv = 1 / scale; - } else { - scale_inv = scale_inv_ptr[bid]; - } - updated_scale_inv_ptr[bid] = scale_inv; } } @@ -232,7 +221,7 @@ __global__ void __launch_bounds__(bsize) } } - // Update scale and scale inverse + // Update scale if (tid == 0) { // Computing the scaling factor requires consideration of the following scenarios: // 1. 
amax == 0: @@ -259,7 +248,6 @@ __global__ void __launch_bounds__(bsize) scale = std::numeric_limits::max(); } p.param[bid].scale[count] = scale; - p.param[bid].scale_inv[count] = 1 / scale; } } } @@ -268,23 +256,12 @@ __global__ void __launch_bounds__(bsize) } // namespace -void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv, - const Tensor& scale_inv_mask, Tensor* updated_amax_history_, - Tensor* updated_scale_, Tensor* updated_scale_inv_, +void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, + Tensor* updated_amax_history_, Tensor* updated_scale_, const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { auto& updated_amax_history = *updated_amax_history_; auto& updated_scale = *updated_scale_; - auto& updated_scale_inv = *updated_scale_inv_; - - // Number of elements in tensor - auto numel = [](const Tensor& tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; // Check tensors NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(), @@ -293,18 +270,9 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons const size_t num_scales = amax_history.data.shape[1]; NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ", dtype_name(amax_history.data.dtype), "."); - NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ", - numel(scale), "."); + NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ", + scale.numel(), "."); NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), "."); - if (scale_inv_mask.data.dptr != nullptr) { - NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ", - numel(scale_inv), "."); - NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); - 
NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(scale_inv_mask), "."); - NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ", - dtype_name(scale_inv_mask.data.dtype), "."); - } NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", updated_amax_history.data.shape.size(), " dims."); NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ", @@ -313,14 +281,10 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons "but found ", updated_amax_history.data.shape[1]); NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ", dtype_name(updated_amax_history.data.dtype), "."); - NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale), "."); + NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ", + "but found ", updated_scale.numel(), "."); NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ", dtype_name(updated_scale.data.dtype), "."); - NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale_inv), "."); - NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ", - dtype_name(updated_scale_inv.data.dtype), "."); // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; @@ -340,11 +304,8 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons const size_t grid_size = num_scales; amax_and_scale_update_impl::kernel<<>>( static_cast(amax_history.data.dptr), static_cast(scale.data.dptr), - static_cast(scale_inv.data.dptr), - static_cast(scale_inv_mask.data.dptr), static_cast(updated_amax_history.data.dptr), - static_cast(updated_scale.data.dptr), - static_cast(updated_scale_inv.data.dptr), amax_history_length, num_scales, + static_cast(updated_scale.data.dptr), 
amax_history_length, num_scales, amax_compute_algo_, scaled_max); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -352,7 +313,6 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, - std::vector scale_invs, const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { using namespace transformer_engine; @@ -370,15 +330,6 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, // Expected maximum value after scale is applied const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); - // Number of elements in tensor - auto numel = [](const Tensor* tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor->data.shape) { - acc *= dim; - } - return acc; - }; - // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); size_t num_remaining_tensors = num_tensors; @@ -404,22 +355,21 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, dtype_name(amax_histories[i]->data.dtype), "."); NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ", amax_histories[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ", + NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ", amax_history_length * num_scale, " elements, ", "but found ", - numel(amax_histories[i]), "."); + amax_histories[i]->numel(), "."); NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ", dtype_name(scales[i]->data.dtype), "."); NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", - numel(scales[i]), "."); + NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " 
elements, ", "Found ", + scales[i]->numel(), "."); // amax parameters kernel_num_scales += num_scale; p.param[pi].num_scale = num_scale; p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); p.param[pi].scale = static_cast(scales[i]->data.dptr); - p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); } // Launch CUDA kernel @@ -441,34 +391,30 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, } // namespace transformer_engine void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, - NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); using namespace transformer_engine; delayed_scaling_recipe::amax_and_scale_update( *reinterpret_cast(amax_history), *reinterpret_cast(scale), - *reinterpret_cast(scale_inv), *reinterpret_cast(scale_inv_mask), reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), - reinterpret_cast(updated_scale_inv), amax_compute_algo, - static_cast(fp8_dtype), margin, stream); + amax_compute_algo, static_cast(fp8_dtype), margin, stream); } void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, cudaStream_t stream) { 
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); using namespace transformer_engine; size_t num_tensors = amax_histories.size(); - std::vector t_amax_histories, t_scales, t_scale_invs; + std::vector t_amax_histories, t_scales; for (size_t i = 0; i < num_tensors; i++) { t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); t_scales.push_back(reinterpret_cast(scales[i])); - t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); } delayed_scaling_recipe::amax_and_scale_update_after_reduction( *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, - t_scale_invs, amax_compute_algo, static_cast(fp8_dtype), margin, stream); + amax_compute_algo, static_cast(fp8_dtype), margin, stream); } diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu new file mode 100644 index 0000000000..a0fffc783c --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -0,0 +1,338 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace { + +constexpr int TB_DIM = 32; +constexpr int NEW_SF_TILE_DIM_K = 16; +constexpr int N_SF_PER_TD_PER_TILE = 4; + +// output is in ~K-major interleaved blocks +constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr int NEW_SF_TILE_DIM_M_I32 = 32; + +template +__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { + // inp, 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] + // out, swapping byte to form new 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15] + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t new_regs[kVectorSize]; + int32_t* regs = reinterpret_cast(regs_vec); + +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { +#pragma unroll + for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) { + new_regs[i * N_SF_PER_TD_PER_TILE + j] = + (((regs[i + 0 * N_TILE_PER_TD] >> 8 * j) & 0xFF)) | + (((regs[i + 1 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 8) | + (((regs[i + 2 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 16) | + (((regs[i + 3 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 24); + } + } +#pragma unroll + for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; +} + +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // input is in M-major + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + + int m_tiles_in_tb = 
N_TILE_PER_TD; + int k_tiles_in_tb = TB_DIM; + if (blockIdx.x == gridDim.x - 1) { + k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; + } + if (blockIdx.y == gridDim.y - 1) { + m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; + } + + const int32_t* input_i32 = reinterpret_cast(input) + + blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + + blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + int32_t* output_i32[N_TILE_PER_TD]; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + + (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + } + extern __shared__ int slm[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && + threadIdx.y < k_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // local shuffle + regs_shuffle_with_bit_shifts(regs_vec); + + // store, regs -> shared + int tM = threadIdx.x * N_SF_PER_TD; + int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 + + tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int i = 0; i < N_SF_PER_TD; i++) { + /* TODO rotate_i */ + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] = + reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + __align__(16) int4* output_v4i = reinterpret_cast(output_i32[i]); + __align__(16) int4* slm_v4i = + reinterpret_cast(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; + j 
+= blockDim.x * blockDim.y) { + output_v4i[j] = slm_v4i[j]; + } + } +} + +template +__device__ inline void regs_shuffle(LType* regs_vec) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + if constexpr (N_TILE_PER_TD == 1) return; + + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t tmp[kVectorSize]; + int32_t* ptr = reinterpret_cast(regs_vec); +#pragma unroll + for (int i = 0; i < kVectorSize; i++) + tmp[i % N_TILE_PER_TD * N_SF_PER_TD_PER_TILE + i / N_TILE_PER_TD] = ptr[i]; + +#pragma unroll + for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; +} + +template +__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // input is in K-major + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; + + int n_tiles_in_tb = N_TILES_IN_TB; + const int K_i32 = K / 4; + if (blockIdx.x == gridDim.x - 1) { + n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; + } + + const int* input_i32 = reinterpret_cast(input) + + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; + int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + + blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + + extern __shared__ int4 slm_v4i[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // shuffle regs + regs_shuffle(regs_vec); + +// store, regs -> shared +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { + /* TODO rotate i */ + slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] = + 
reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + __align__(16) int4* output_v4i = reinterpret_cast(output_i32); +#pragma unroll + for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { + output_v4i[i] = slm_v4i[i]; + } +} + +} // namespace + +namespace transformer_engine { + +void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { + if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); + } + + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + + auto& scaling_mode = input->scaling_mode; + + // 1D block scaling, row-wise or column-wise + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + const int m = + input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; + const int k = + input->has_data() ? 
input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + dim3 block_size(TB_DIM, TB_DIM); + if (input->has_data()) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 2: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 1: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + 
<<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + if (input->has_columnwise_data()) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 2: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 1: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + + // 2D block scaling + } else { + NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + exit(-1); + } +} +} // namespace transformer_engine + +/* + * WIP (Phuong): + * - Opt for bank conflicts + * - Adding swizzle for 2d-block scaling. 
+*/ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_scaling_factors); + using namespace transformer_engine; + swizzle_scaling_factors(reinterpret_cast(input), reinterpret_cast(output), + stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 11e0e319ed..faf6ec990d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -6,71 +6,196 @@ #include +#include + #include "common.h" namespace transformer_engine { -size_t typeToSize(const transformer_engine::DType type) { +size_t typeToSize(const DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, return TypeInfo::size;); // NOLINT(*) } -bool is_fp8_dtype(const transformer_engine::DType t) { - return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2; +bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } + +std::string to_string(const DType type) { + switch (type) { + case DType::kByte: + return "Byte"; + case DType::kBFloat16: + return "BFloat16"; + case DType::kFloat16: + return "Float16"; + case DType::kFloat32: + return "Float32"; + case DType::kFloat8E4M3: + return "Float8E4M3"; + case DType::kFloat8E5M2: + return "Float8E5M2"; + case DType::kFloat8E8M0: + return "Float8E8M0"; + case DType::kInt32: + return "Int32"; + case DType::kInt64: + return "Int64"; + default: + return concat_strings("Invalid type ", static_cast(type)); + } +} + +std::string to_string(const NVTEScalingMode &mode) { + switch (mode) { + case NVTE_DELAYED_TENSOR_SCALING: + return "Delayed Tensor Scaling"; + case NVTE_MXFP8_1D_SCALING: + return "MXFP8 1D Scaling"; + case NVTE_INVALID_SCALING: + return "Invalid Scaling"; + } + return "Invalid Scaling"; +} + +void CheckNoopTensor(const Tensor &t, const std::string &name) { + if (t.data.dptr != 
nullptr) { + NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(), + "."); + NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name, + " noop. Expected kFloat32."); + } +} + +void CheckScaleTensorShape(const Tensor &t, const std::string &name) { + NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!"); + if (is_tensor_scaling(t.scaling_mode)) { + // per-tensor scaling + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected (1), got ", + t.columnwise_scale_inv.shape, ")"); + } + } else { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { + // Need (4, 128) alignment even for e8 scaling factor + auto block_alignment = std::vector{128ul, 4ul}; + size_t expected_x, expected_y, alignment; + + if (t.has_data()) { + alignment = block_alignment[0]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; + alignment = block_alignment[1]; + expected_y = + DIVUP(DIVUP(t.flat_last_dim(), static_cast(32)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + alignment = block_alignment[1]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(32)), alignment) * alignment; + alignment = block_alignment[0]; + expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape 
(expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } + } + } } void CheckInputTensor(const Tensor &t, const std::string &name) { - const DType type = t.data.dtype; + const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale."); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{1}); + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { - NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.scale_inv.dptr == nullptr, - "Scale_inv is not supported for non-FP8 input " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported 
for non-FP8 input ", name); } - NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!"); + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); + + CheckScaleTensorShape(t, name); } void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { - const DType type = t.data.dtype; + const DType type = t.dtype(); if (is_fp8_dtype(type)) { - // FP8 output needs to have scale, amax and scale_inv - NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor."); - NVTE_CHECK(t.amax.dtype == DType::kFloat32); - NVTE_CHECK(t.amax.shape == std::vector{1}); - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale."); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{1}); - NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale."); - NVTE_CHECK(t.scale.dtype == DType::kFloat32); - NVTE_CHECK(t.scale.shape == std::vector{1}); + // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor"); + NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", + to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); + NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, + " (expected 1 entry, got shape=", t.amax.shape, ")"); + } + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + 
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { - NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.scale_inv.dptr == nullptr, - "Scale_inv is not supported for non-FP8 output " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); } if (!allow_empty) { - NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!"); + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!"); } + + CheckScaleTensorShape(t, name); } } // namespace transformer_engine -NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax, - float *scale, float *scale_inv) { +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { transformer_engine::Tensor *ret = new transformer_engine::Tensor; - ret->data.dptr = dptr; - ret->data.shape = std::vector(shape.data, shape.data + shape.ndim); - ret->data.dtype = static_cast(dtype); - ret->amax.dptr = amax; - ret->scale.dptr = scale; - ret->scale_inv.dptr = scale_inv; + ret->scaling_mode = scaling_mode; return ret; } @@ -81,30 +206,65 @@ void 
nvte_destroy_tensor(NVTETensor tensor) { } NVTEDType nvte_tensor_type(const NVTETensor tensor) { + if (tensor == nullptr) return kNVTEFloat32; return static_cast( - reinterpret_cast(tensor)->data.dtype); + reinterpret_cast(tensor)->dtype()); } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; const auto &t = *reinterpret_cast(tensor); NVTEShape ret; + + // FP8 tensor keeps shape in rowwise data + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; + } + + // Get shape based on what data is available + if (t.has_data()) { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; + } + if (t.has_columnwise_data()) { + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; + } + + // Tensor has no data ret.data = t.data.shape.data(); ret.ndim = t.data.shape.size(); return ret; } +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; +} + size_t nvte_tensor_ndim(const NVTETensor tensor) { + if (tensor == nullptr) return 0; const auto &t = *reinterpret_cast(tensor); return t.data.shape.size(); } size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { + if (tensor == nullptr) return 0; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); return t.data.shape[dim]; } size_t nvte_tensor_numel(const NVTETensor tensor) { + if (tensor == nullptr) return 0; const auto &t = *reinterpret_cast(tensor); size_t numel = 1; for (auto size : t.data.shape) { @@ -114,16 +274,25 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { } size_t nvte_tensor_element_size(const 
NVTETensor tensor) { + if (tensor == nullptr) return sizeof(float); const auto &t = *reinterpret_cast(tensor); return transformer_engine::typeToSize(t.data.dtype); } void *nvte_tensor_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); return t.data.dptr; } +void *nvte_tensor_columnwise_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_data.dptr; +} + float *nvte_tensor_amax(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, "Tensor's amax must have Float32 type!"); @@ -131,6 +300,7 @@ float *nvte_tensor_amax(const NVTETensor tensor) { } float *nvte_tensor_scale(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, "Tensor's scale must have Float32 type!"); @@ -138,12 +308,83 @@ float *nvte_tensor_scale(const NVTETensor tensor) { } float *nvte_tensor_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32, - "Tensor's inverse of scale must have Float32 type!"); return reinterpret_cast(t.scale_inv.dptr); } +void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_scale_inv.dptr; +} + +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.scale_inv.shape.data(); + ret.ndim = t.scale_inv.shape.size(); + return ret; +} + +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + 
const NVTEBasicTensor *param) { + NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); + NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated."); + auto &t = *reinterpret_cast(*tensor); + switch (param_name) { + case kNVTERowwiseData: + t.data = *param; + break; + case kNVTEColumnwiseData: + t.columnwise_data = *param; + break; + case kNVTEScale: + t.scale = *param; + break; + case kNVTEAmax: + t.amax = *param; + break; + case kNVTERowwiseScaleInv: + t.scale_inv = *param; + break; + case kNVTEColumnwiseScaleInv: + t.columnwise_scale_inv = *param; + break; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { + if (tensor == nullptr) { + return {nullptr, kNVTEFloat32, {nullptr, 0}}; + } + const auto &t = *reinterpret_cast(tensor); + switch (param_name) { + case kNVTERowwiseData: + return t.data; + case kNVTEColumnwiseData: + return t.columnwise_data; + case kNVTEScale: + return t.scale; + case kNVTEAmax: + return t.amax; + case kNVTERowwiseScaleInv: + return t.scale_inv; + case kNVTEColumnwiseScaleInv: + return t.columnwise_scale_inv; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.scaling_mode; +} + void nvte_tensor_pack_create(NVTETensorPack *pack) { for (int i = 0; i < pack->MAX_SIZE; i++) { pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); @@ -156,3 +397,18 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) { delete t; } } + +void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { + const auto &t = *reinterpret_cast(tensor); + // Zero out tensor data if allocated + if (t.data.dptr != nullptr) { + size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); + cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + } + // Set amax to 0 if allocated + 
if (t.amax.dptr != nullptr) { + float zero = 0.0f; + cudaMemcpyAsync(t.amax.dptr, &zero, sizeof(float), cudaMemcpyHostToDevice, stream); + } + cudaStreamSynchronize(stream); +} diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index b49c61195e..4cdb39b70a 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -10,12 +10,12 @@ #include -#include "../common.h" #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "cast_transpose.h" -namespace transformer_engine { +namespace transformer_engine::detail { namespace { @@ -217,159 +217,143 @@ __global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( } // namespace -void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_, - Tensor *transposed_output_, cudaStream_t stream) { - Tensor &cast_output = *cast_output_; - Tensor &transposed_output = *transposed_output_; +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) { + Tensor &output = *output_; - // Check no-op flag - if (noop.data.dptr != nullptr) { - size_t numel = 1; - for (const auto &dim : noop.data.shape) { - numel *= dim; - } - NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); - NVTE_CHECK(noop.data.dtype == DType::kFloat32); - NVTE_CHECK(noop.data.dptr != nullptr); - } - - // Check tensor dims + CheckNoopTensor(noop, "cast_transpose_noop"); CheckInputTensor(input, "cast_transpose_input"); - CheckOutputTensor(cast_output, "cast_output"); - CheckOutputTensor(transposed_output, "transposed_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions."); - const size_t 
row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; - NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); - NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); - NVTE_CHECK(transposed_output.data.shape[0] == row_length, - "Wrong dimension of transposed output."); - NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output."); - - // Check tensor pointers - NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); - NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated."); - NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated."); - NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype, + CheckOutputTensor(output, "cast_transpose_output"); + + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output.has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output.has_columnwise_data(), "Output columnwise is not allocated"); + + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output.data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output.data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); + NVTE_CHECK(output.flat_first_dim() == num_rows && output.flat_last_dim() == row_length, + "Invalid output dimensions (expected ", std::vector{num_rows, row_length}, + ", got ", std::vector{output.flat_first_dim(), output.flat_last_dim()}, ")"); + + // Check that cast and transposed output data matches + NVTE_CHECK(output.data.dtype == output.columnwise_data.dtype, "Cast and transposed output types must match."); - NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr, - "Cast and transposed outputs need to share amax tensor."); - NVTE_CHECK(cast_output.scale.dptr == 
transposed_output.scale.dptr, - "Cast and transposed outputs need to share scale tensor."); - NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr, + NVTE_CHECK(output.scale_inv.dptr == output.columnwise_scale_inv.dptr, "Cast and transposed outputs need to share scale-inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output.data.dtype, OutputType, - constexpr const char *itype_name = TypeInfo::name; - constexpr const char *otype_name = TypeInfo::name; - constexpr size_t itype_size = sizeof(InputType); - constexpr size_t otype_size = sizeof(OutputType); - - // Choose between runtime-compiled or statically-compiled kernel - const bool aligned = - (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); - if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Pick kernel config - std::vector kernel_configs; - kernel_configs.reserve(16); - const size_t sm_count = static_cast(cuda::sm_count()); - auto add_config = [&](size_t load_size, size_t store_size) { - kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, - store_size, sm_count); - }; - add_config(8, 8); - add_config(4, 8); - add_config(8, 4); - add_config(4, 4); - add_config(2, 8); - add_config(8, 2); - add_config(2, 4); - add_config(4, 2); - add_config(2, 2); - add_config(1, 8); - add_config(8, 1); - add_config(1, 4); - add_config(4, 1); - add_config(1, 2); - add_config(2, 1); - add_config(1, 1); - const auto &kernel_config = - *std::min_element(kernel_configs.begin(), kernel_configs.end()); - NVTE_CHECK(kernel_config.valid, "invalid kernel config"); - const size_t load_size = kernel_config.load_size; - const size_t store_size = kernel_config.store_size; - const size_t num_blocks = kernel_config.num_blocks; - - // Compile NVRTC kernel if needed and launch - auto &rtc_manager = rtc::KernelManager::instance(); - const 
std::string kernel_label = concat_strings( - "cast_transpose" - ",itype=", - itype_name, ",otype=", otype_name, ",load_size=", load_size, - ",store_size=", store_size); - if (!rtc_manager.is_compiled(kernel_label)) { - std::string code = string_code_transpose_rtc_cast_transpose_cu; - code = regex_replace(code, "__ITYPE__", itype_name); - code = regex_replace(code, "__OTYPE__", otype_name); - code = regex_replace(code, "__LOAD_SIZE__", load_size); - code = regex_replace(code, "__STORE_SIZE__", store_size); - code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); - code = regex_replace(code, "__BLOCK_SIZE__", block_size); - rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, - "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + output.dtype(), OutputType, + if (is_delayed_tensor_scaling(output.scaling_mode)) { + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + constexpr size_t itype_size = sizeof(InputType); + constexpr size_t otype_size = sizeof(OutputType); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = + (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + const size_t sm_count = static_cast(cuda::sm_count()); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, + store_size, sm_count); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = + 
*std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "cast_transpose" + ",itype=", + itype_name, ",otype=", otype_name, ",load_size=", load_size, + ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_cu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, + "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + } + rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; + const int num_blocks = + (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel + <<>>( + static_cast(input.data.dptr), + 
reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); } - rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - static_cast(cast_output.scale_inv.dptr), row_length, - num_rows); - } else { // Statically-compiled general kernel - constexpr size_t load_size = 4; - constexpr size_t store_size = 4; - constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; - constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; - const int num_blocks = - (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); - cast_transpose_general_kernel - <<>>( - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - static_cast(cast_output.scale_inv.dptr), row_length, num_rows); + } else { + NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); }); // NOLINT(*) ); // NOLINT(*) } -} // namespace transformer_engine +} // namespace transformer_engine::detail -void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, cudaStream_t stream) { +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; auto noop = Tensor(); - cast_transpose(*reinterpret_cast(input), noop, - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), stream); + 
transformer_engine::detail::cast_transpose(*reinterpret_cast(input), noop, + reinterpret_cast(output), stream); } -void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, - NVTETensor cast_output, NVTETensor transposed_output, +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_with_noop); using namespace transformer_engine; - cast_transpose(*reinterpret_cast(input), *reinterpret_cast(noop), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), stream); + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h new file mode 100644 index 0000000000..ed9bd5f5f7 --- /dev/null +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -0,0 +1,28 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ + +#include "../common.h" + +namespace transformer_engine::detail { + +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream); + +template +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream); + +template +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, + cudaStream_t stream); + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index ed919c8b94..8347e117ce 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -8,18 +8,19 @@ #include #include -#include +#include +#include #include -#include "../common.h" #include "../util/math.h" #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "cast_transpose.h" namespace transformer_engine { -namespace { +namespace detail { // String with RTC kernel implementation #include "string_code_transpose_rtc_cast_transpose_fusion_cu.h" @@ -177,16 +178,31 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ Tensor *workspace, const int nvec_out) { - const size_t row_length = cast_output.data.shape[1]; - const size_t num_rows = cast_output.data.shape[0]; + const size_t row_length = cast_output.flat_last_dim(); + const size_t num_rows = cast_output.flat_first_dim(); const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); 
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } } template @@ -248,11 +264,13 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reduce_dbias_num_rows); } -template +template __global__ void __launch_bounds__(cast_transpose_num_threads) cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length, const size_t num_rows, const size_t num_tiles) { + static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); using IType = typename Param::InputType; using IType2 = typename Param::InputType2; using OType = typename Param::OutputType; @@ -373,6 +391,8 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) if constexpr (IS_DACT) { after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * OP(act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + after_dact[j].data.elt[k] = 
OP(in[current_in ^ 1][j].data.elt[k], {}); } else { after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]); } @@ -449,78 +469,96 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) } static const char *ActTypeToString[] = { - "NoAct", // 0 - "Sigmoid", // 1 - "GeLU", // 2 - "QGeLU", // 3 - "SiLU", // 4 - "ReLU", // 5 - "SReLU" // 6 + "none", // 0 + "sigmoid", // 1 + "dsigmoid", // 2 + "gelu", // 3 + "dgelu", // 4 + "qgelu", // 5 + "dqgelu", // 6 + "silu", // 7 + "dsilu", // 8 + "relu", // 9 + "drelu", // 10 + "srelu", // 11 + "dsrelu" // 12 }; template -int get_dactivation_type() { - if (OP == &sigmoid) { - return 1; - } else if (OP == &dgelu) { - return 2; - } else if (OP == &dqgelu) { - return 3; - } else if (OP == &dsilu) { - return 4; - } else if (OP == &drelu) { - return 5; - } else if (OP == &dsrelu) { - return 6; - } else { - return 0; +constexpr int get_activation_type() { + constexpr decltype(OP) ActivationList[] = { + nullptr, // 0 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 + }; +#pragma unroll + for (int i = 0; i < sizeof(ActivationList) / sizeof(ActivationList[0]); ++i) { + if (OP == ActivationList[i]) { + return i; + } } + return 0; } -template -void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output, - Tensor *transposed_output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (workspace->data.dptr != nullptr) { +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + // Check tensors, unless querying dbias workspace + if (!IS_DBIAS || workspace->data.dptr != nullptr) { CheckInputTensor(input, "cast_transpose_fused_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - if 
constexpr (IS_DBIAS) CheckOutputTensor(*dbias, "dbias"); - if constexpr (IS_DACT) CheckInputTensor(act_input, "act_input"); + CheckOutputTensor(*output, "output"); + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr && dbias->has_data()); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr && act_input->has_data()); + CheckInputTensor(*act_input, "act_input"); + } } - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output->has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output->has_columnwise_data(), "Output columnwise data is not allocated"); - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output->data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); + // Check 
that cast and transposed output data matches + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); } if constexpr (IS_DACT) { - NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match."); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); } TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output->data.dtype, OutputType, using InputType2 = InputType; + output->dtype(), OutputType, using InputType2 = InputType; using Param = CTDBiasDActParam; constexpr int itype_size = sizeof(InputType); @@ -584,8 +622,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * if (!jit_compiled) { num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block); } if constexpr (IS_DBIAS) { + // Check workspace size + populate_cast_transpose_dbias_workspace_config(*output, workspace, nvec_out); if (workspace->data.dptr == nullptr) { - populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out); return; } } @@ -631,15 +670,15 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * Param param; param.input = reinterpret_cast(input.data.dptr); - param.output_c = reinterpret_cast(cast_output->data.dptr); - param.output_t = 
reinterpret_cast(transposed_output->data.dptr); - param.scale_ptr = reinterpret_cast(transposed_output->scale.dptr); - param.amax = reinterpret_cast(transposed_output->amax.dptr); - param.scale_inv = reinterpret_cast(cast_output->scale_inv.dptr); + param.output_c = reinterpret_cast(output->data.dptr); + param.output_t = reinterpret_cast(output->columnwise_data.dptr); + param.scale_ptr = reinterpret_cast(output->scale.dptr); + param.amax = reinterpret_cast(output->amax.dptr); + param.scale_inv = reinterpret_cast(output->scale_inv.dptr); if constexpr (IS_DBIAS) { param.workspace = reinterpret_cast(workspace->data.dptr); } if constexpr (IS_DACT) { - param.act_input = reinterpret_cast(act_input.data.dptr); + param.act_input = reinterpret_cast(act_input->data.dptr); } // Runtime-compiled tuned kernel @@ -648,9 +687,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * constexpr const char *itype2_name = TypeInfo::name; constexpr const char *otype_name = TypeInfo::name; - int dActType = 0; - if constexpr (IS_DACT) { - dActType = get_dactivation_type(); + int actType = 0; + if constexpr (IS_DACT || IS_ACT) { + actType = get_activation_type(); } // Compile NVRTC kernel if needed and launch @@ -660,7 +699,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * ",itype=", itype_name, ",itype2=", itype2_name, ",otype=", otype_name, ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS, - ",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]); + ",IS_DACT=", IS_DACT, ",IS_ACT=", IS_ACT, + ",activationType=", ActTypeToString[actType]); if (!rtc_manager.is_compiled(kernel_label)) { std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu; @@ -673,7 +713,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads); code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS); code 
= regex_replace(code, "__IS_DACT__", IS_DACT); - code = regex_replace(code, "__DACTIVATION_TYPE__", dActType); + code = regex_replace(code, "__IS_ACT__", IS_ACT); + code = regex_replace(code, "__ACTIVATION_TYPE__", actType); rtc_manager.compile( kernel_label, "cast_transpose_fusion_kernel_optimized", code, @@ -695,11 +736,11 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); cudaFuncSetAttribute( - cast_transpose_fused_kernel_notaligned, + cast_transpose_fused_kernel_notaligned, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - cast_transpose_fused_kernel_notaligned + cast_transpose_fused_kernel_notaligned <<>>( param, row_length, num_rows, num_tiles); } @@ -1101,43 +1142,39 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) template -void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, - Tensor *cast_output, Tensor *transposed_output, +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "dgated_act_cast_transpose_input"); CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input"); - CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output"); - CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output"); + CheckOutputTensor(*output, "dgated_act_cast_transpose_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + NVTE_CHECK(output->has_data() && output->has_columnwise_data(), + "Both rowwise and columnwise data need to be allocated."); + NVTE_CHECK(output->data.shape.size() == 2, "C 
output must have 2 dimensions."); + NVTE_CHECK(output->columnwise_data.shape.size() == 2, "T output must have 2 dimensions."); const size_t row_length = input.data.shape[1]; const size_t num_rows = input.data.shape[0]; NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output."); NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + NVTE_CHECK(output->data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(output->data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(output->columnwise_data.shape[0] == row_length * 2, "Wrong dimension of T output."); + NVTE_CHECK(output->columnwise_data.shape[1] == num_rows, "Wrong dimension of T output."); NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr, + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, "C and T outputs need to share scale inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - 
cast_output->data.dtype, OutputType, using InputType2 = InputType; + output->dtype(), OutputType, using InputType2 = InputType; /* dact fusion kernel uses more registers */ constexpr int desired_load_size_dact = 4; constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType); @@ -1168,11 +1205,11 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu <<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); } else { cudaFuncSetAttribute( @@ -1184,194 +1221,193 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu <<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); }); // NOLINT(*) ); // NOLINT(*) } -} // namespace + +// Explicit template instantiation +template void cast_transpose_fused( + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); +#define NVTE_INSTANTIATE_ACTIVATION(op) \ + template void 
cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); +NVTE_INSTANTIATE_ACTIVATION(relu); +NVTE_INSTANTIATE_ACTIVATION(srelu); +NVTE_INSTANTIATE_ACTIVATION(gelu); +NVTE_INSTANTIATE_ACTIVATION(qgelu); +NVTE_INSTANTIATE_ACTIVATION(silu); +#undef NVTE_INSTANTIATE_ACTIVATION + +} // namespace detail } // namespace transformer_engine using ComputeType = typename transformer_engine::fp32; -void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; constexpr const NVTETensor activation_input = nullptr; - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(activation_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused( + *reinterpret_cast(input), reinterpret_cast(activation_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool 
IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dgelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(act_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dsilu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(silu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(silu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; 
- constexpr auto dActivation = &drelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(relu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(relu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dsrelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(srelu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(srelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dqgelu; - - cast_transpose_fused( - 
*reinterpret_cast(input), *reinterpret_cast(qgelu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(qgelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dgelu; - constexpr auto Activation = &gelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, gelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dsilu; - constexpr auto Activation = &silu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, silu>( *reinterpret_cast(input), *reinterpret_cast(swiglu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu_cast_transpose); using namespace transformer_engine; + using 
namespace transformer_engine::detail; - constexpr auto dActivation = &drelu; - constexpr auto Activation = &relu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, relu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dsrelu; - constexpr auto Activation = &srelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, srelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dqgelu; - constexpr auto Activation = &qgelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, qgelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 16894ad4b5..5cf316f45e 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ 
b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -195,42 +195,44 @@ __global__ void __launch_bounds__(threads_per_block) } // namespace -void multi_cast_transpose(const std::vector input_list, - std::vector cast_output_list, - std::vector transposed_output_list, cudaStream_t stream) { +void multi_cast_transpose(const std::vector input_list, std::vector output_list, + cudaStream_t stream) { // Check that number of tensors is valid - NVTE_CHECK(cast_output_list.size() == input_list.size(), - "Number of input and C output tensors must match"); - NVTE_CHECK(transposed_output_list.size() == input_list.size(), - "Number of input and T output tensors must match"); + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); if (input_list.empty()) { return; } // Check that tensor properties are valid DType itype = input_list[0]->data.dtype; - DType otype = cast_output_list[0]->data.dtype; + DType otype = output_list[0]->dtype(); for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { const auto& input = *input_list[tensor_id]; - const auto& cast_output = *cast_output_list[tensor_id]; - const auto& transposed_output = *transposed_output_list[tensor_id]; + const auto& output = *output_list[tensor_id]; CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id)); - CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id)); - CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_cast_transpose_output_" + std::to_string(tensor_id)); + //std::cout << *static_cast(output.data.dptr) << std::endl; + NVTE_CHECK(output.has_data() && output.has_columnwise_data(), + "Both rowwise and columnwise output data needs to be allocated."); NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); - NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match."); - 
NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "C output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "T output tensor types do not match."); - NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); - NVTE_CHECK(cast_output.data.shape == input.data.shape, - "C output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, - "T output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape[0] == input.data.shape[1], - "T output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape[1] == input.data.shape[0], - "T output tensor shape does not match input tensor."); + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions, but shape is ", + input.data.shape); + NVTE_CHECK(output.data.shape == input.data.shape, "C output tensor shape ", output.data.shape, + "does not match input tensor shape ", input.data.shape); + NVTE_CHECK(output.columnwise_data.shape.size() == 2, "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[0] == input.data.shape[1], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[1] == input.data.shape[0], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); } // Input matrices are divided into tiles @@ -287,11 +289,11 @@ void multi_cast_transpose(const std::vector input_list, // Add tensor to kernel argument struct const int pos = kernel_args.num_tensors; kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); - kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->data.dptr; - 
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; - kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; - kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; - kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr; + kernel_args.output_c_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.output_t_list[pos] = output_list[tensor_id]->columnwise_data.dptr; + kernel_args.scale_list[pos] = output_list[tensor_id]->scale.dptr; + kernel_args.amax_list[pos] = output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = output_list[tensor_id]->scale_inv.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; @@ -327,15 +329,13 @@ void multi_cast_transpose(const std::vector input_list, } // namespace transformer_engine void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, - NVTETensor* cast_output_list, NVTETensor* transposed_output_list, - cudaStream_t stream) { + NVTETensor* output_list, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_cast_transpose); using namespace transformer_engine; - std::vector input_list_, cast_output_list_, transposed_output_list_; + std::vector input_list_, output_list_; for (size_t i = 0; i < num_tensors; ++i) { input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); - cast_output_list_.push_back(reinterpret_cast(cast_output_list[i])); - transposed_output_list_.push_back(reinterpret_cast(transposed_output_list[i])); + output_list_.push_back(reinterpret_cast(output_list[i])); } - multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream); + multi_cast_transpose(input_list_, output_list_, stream); } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index 
2424247bbe..34359561aa 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -22,7 +22,9 @@ constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; constexpr bool IS_DBIAS = __IS_DBIAS__; constexpr bool IS_DACT = __IS_DACT__; -constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__; +constexpr bool IS_ACT = __IS_ACT__; +static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); +constexpr size_t ACT_TYPE = __ACTIVATION_TYPE__; constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType); constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType); @@ -33,14 +35,20 @@ using OVec = Vec; using Param = CTDBiasDActParam; using OP = CType (*)(const CType, const Empty &); -constexpr OP Activation[] = { +constexpr OP ActivationList[] = { nullptr, // 0 - &dsigmoid, // 1 - &dgelu, // 2 - &dqgelu, // 3 - &dsilu, // 4 - &drelu, // 5 - &dsrelu // 6 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 }; } // namespace @@ -175,7 +183,10 @@ __global__ void __launch_bounds__(BLOCK_SIZE) if constexpr (IS_DACT) { in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]) * - Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + ActivationList[ACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + in_cast_fp32[j].data.elt[k] = + ActivationList[ACT_TYPE](in[current_in ^ 1][j].data.elt[k], {}); } else { in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]); } diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 339748ead0..26740a3837 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ 
b/transformer_engine/common/transpose/transpose.cu @@ -205,17 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); - // Number of elements in tensor - auto numel = [](const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto &dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; - if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), "."); + NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), "."); NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dptr != nullptr); } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index 39c702dade..fba3710beb 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include "../common.h" #include "../utils.cuh" @@ -376,8 +376,24 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + if (workspace->data.dptr == nullptr) { + // Set workspace size + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + 
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } } template @@ -426,10 +442,9 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor constexpr int nvec_in = desired_load_size / type_size; constexpr int nvec_out = desired_store_size / type_size; - if (workspace->data.dptr == nullptr) { - populate_transpose_dbias_workspace_config(input, workspace, nvec_out); - return; - } + // Check workspace size + populate_transpose_dbias_workspace_config(input, workspace, nvec_out); + if (workspace->data.dptr == nullptr) { return; } NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index e0c92c22cb..22a50025df 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -4,88 +4,144 @@ * See LICENSE for license information. 
************************************************************************/ +#include +#include +#include #include +#include +#include +#include + #include "../common.h" +#include "../transpose/cast_transpose.h" #include "../util/vectorized_pointwise.h" #include "../utils.cuh" +#include "cast_kernels.cuh" +#include "dequantize_kernels.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; -namespace transformer_engine { + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; -namespace detail { + detail::quantize_helper(input, grad, nullptr, output, + dbias, workspace, stream); +} -struct Empty {}; +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; -__device__ inline fp32 identity(fp32 value, const Empty &) { return value; } + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; -struct DequantizeParam { - const fp32 *scale_inv; -}; + detail::quantize_helper(input, grad, noop, output, + dbias, workspace, stream); +} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; -__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); + constexpr bool 
IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr const NVTETensor activation_input = nullptr; + + detail::quantize_helper( + activation_input, input, nullptr, output, dbias, workspace, stream); } -} // namespace detail - -void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision."); - - NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher 
precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -} // namespace transformer_engine +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, 
nullptr, output, dbias, workspace, stream); +} -void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fp8_quantize); +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); using namespace transformer_engine; - fp8_quantize(*reinterpret_cast(input), reinterpret_cast(output), - stream); + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fp8_dequantize); +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; - fp8_dequantize(*reinterpret_cast(input), reinterpret_cast(output), - stream); + detail::dequantize_helper(*reinterpret_cast(input), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh new file mode 100644 index 0000000000..e2240ba658 --- /dev/null +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -0,0 +1,1091 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_gated_kernels.cuh + * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" + +namespace transformer_engine { + +template +__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { + return DIVUP(static_cast(N), static_cast(M)) * M; +} + +namespace gated_kernels { + +constexpr size_t ALIGNMENT_SIZE = 128; +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = 
blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t out_gate_mem = buff_size_aligned_out; + constexpr size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + // uint64_t *mbar = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + if (next_it < ITERATIONS) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + 
OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, {}) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. 
+ if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int 
tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + const size_t in_transaction_size = (IS_DGATED ? 
3 : 2) * buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + + OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + OType *out_act_colwise_sh = out_act_rowwise_sh; + OType *out_gate_colwise_sh = out_gate_rowwise_sh; + + if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + out_gate_colwise_sh = + reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + } + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act_rowwise = + reinterpret_cast(&tensor_map_output_act_rowwise); + const uint64_t *TMAP_output_gate_rowwise = + reinterpret_cast(&tensor_map_output_gate_rowwise); + const uint64_t *TMAP_output_act_colwise = + reinterpret_cast(&tensor_map_output_act_colwise); + const uint64_t *TMAP_output_gate_colwise = + reinterpret_cast(&tensor_map_output_gate_colwise); + + __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + const bool is_master_thread = (threadIdx.x == 0); + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. 
+#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + + // Prefetch data of the first stage + if (is_master_thread) { + // Initiate bulk tensor copy + // Grad + if constexpr (IS_DGATED) { + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), + TMAP_grad_in, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + } + + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), + TMAP_in_act, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[0]); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; + if (next_it < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + if constexpr (IS_DGATED) { + // Grad + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + } + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared( + 
reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_it]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; + OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; + OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; + OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; + + // Assuming one iteration covers exactly 32 rows + const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; + const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; + + float after_dact_reg[BUFFER_STAGES_NUM]; + float after_dgate_reg[BUFFER_STAGES_NUM]; + float thread_Y_mx_block_amax = 0.0f; + float thread_Y_mx_block_amax_gate = 0.0f; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if 
constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; + after_dgate_reg[stage] = act_x * grad_elt; + } else { + after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + } + + if constexpr (USE_ROWWISE_SCALING) { + if constexpr (IS_DGATED) { + // dgate + float amax = fabsf(after_dgate_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_gate_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dgate_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + float amax = fabsf(after_dact_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_act_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dact_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + 
iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + + if constexpr (USE_COLWISE_SCALING) { + __builtin_assume(thread_Y_mx_block_amax >= 0); + __builtin_assume(thread_Y_mx_block_amax_gate >= 0); + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); + if constexpr (IS_DGATED) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool row_out_of_bounds = (row_base >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + if constexpr (IS_DGATED) { + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; 
+ const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_gate_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dgate_reg[stage]); + } + } + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for 
(int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_act_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dact_reg[stage]); + } + } // endif USE_COLWISE_SCALING + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_rowwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_rowwise_sh_curr)); + } + } + + // dGeLU + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_colwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_colwise_sh_curr)); + } + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem); // + mbar_mem; + + cudaFuncSetAttribute( + cast_fp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_fp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + + if (USE_ROWWISE_SCALING) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + // TODO: Make more general + const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; + const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; + size_t scale_stride_colwise = USE_COLWISE_SCALING ? 
output->columnwise_scale_inv.shape[1] : 1; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + e8m0_t *const scales_rowwise_ptr = + USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, + sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, + sizeof(OType)); + } + + if (USE_COLWISE_SCALING) { + 
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + cols, sizeof(OType)); + } + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; + + const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; + + cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "gated_act_input"); + CheckOutputTensor(*output, "gated_act_output"); + 
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); + NVTE_CHECK(input.data.shape[0] == output->data.shape[0], + "Input shape[0] must be equal to output shape[0]."); + NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, + "Input shape[1] must be 2x larger than output shape[1]."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], + output->data.shape[1], {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(grad, "dgated_act_grad"); + CheckInputTensor(input, "dgated_act_input"); + CheckOutputTensor(*output, "dgated_act_output"); + NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), + "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), + ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, + "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match. 
Input shape: ", input.data.shape, + ", output shape: ", output->data.shape, "."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + checkCuDriverContext(stream); + constexpr bool allow_empty = false; + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", allow_empty); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; + + if constexpr (IS_DGATED) { + CheckInputTensor(grad, "grad"); + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); + } + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + bool is_fp8_rowwise_output = true; + bool is_fp8_colwise_output = true; + if (output->has_data()) { + is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + if (output->has_columnwise_data()) { + is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + + const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; + + if (is_delayed_tensor_scaling(output->scaling_mode)) { + if (use_tma_kernels) { + cast_fp8_gated(grad, gated_input, output, stream); + } else { + if constexpr (IS_DGATED) { + cast_dgated(grad, gated_input, output, stream); + } else { + cast_gated(gated_input, output, stream); + } + } + } else if (is_mxfp_scaling(output->scaling_mode)) { + if (use_tma_kernels) { + cast_mxfp8_gated(grad, gated_input, output, stream); + } else { + NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", + "by 32, got input of shape ", gated_input.data.shape); + } + } else { + NVTE_ERROR("Not supported scaling mode"); + } +} +} // namespace gated_kernels + +namespace detail { + +template +void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, + cudaStream_t stream) { + using namespace gated_kernels; + Tensor grad_empty_tensor; + const Tensor &grad_tensor = + IS_DGATED ? *(reinterpret_cast(grad)) : grad_empty_tensor; + const Tensor gated_input_tensor = *reinterpret_cast(gated_input); + Tensor *output_tensor = reinterpret_cast(output); + + if (is_supported_by_CC_100()) { + quantize_gated(grad_tensor, gated_input_tensor, + output_tensor, stream); + } else { + if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { + if constexpr (IS_DGATED) { + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + } else { + cast_gated(gated_input_tensor, output_tensor, stream); + } + } else { + // MX scaling + NVTE_ERROR("Not supported by the Arch < 10.0"); + } + } +} +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh new file mode 100644 index 0000000000..404babc745 --- /dev/null +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -0,0 +1,1251 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_kernels.cuh + * \brief CUDA kernels to cast to/from FP8/MXFP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ + +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ 
CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + if (noop != nullptr && noop[0] == 1.0f) return; + } + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int 
scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = blockIdx.y; + const int dbias_colwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_rowwise[i].clear(); + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_colwise[i] = 0; + } + } + } + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + 
OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; +#pragma unroll + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], 
&tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec act_in; + Vec out_c; + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) 
{ + act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + } + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + // Only single thread writes the computed scaling factor + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + float 
in_compute[SCALE_DIM_Y]; + + float amax = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise[chunk_X] += elt; + } + } + in_compute[i] = elt; + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; + constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + + if (tid_rowwise_Y > 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + partial_dbias_rowwise[c].store_to( + &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); + } + } + __syncthreads(); + + if (tid_rowwise_Y == 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + Vec other_row_dbias; + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + + const int left_bound = dbias_rowwise_offset_X; + const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + +#pragma unroll + for (int i = 0; i < Y; 
++i) { + other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; + } + } + + // Vectorized store when all elements are inside the boundaries + if (right_bound < cols) { + partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + // Element-by-element store when some elements cross the boundaries + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, + in_bound_elts_count); + } + } + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; + } + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; 
// 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const int dbias_offset_Y = blockIdx.y + tid_Y; + const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const int dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const int chunk_offset_Y = block_offset_Y; + const int chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const int buff = iter % FP8_BUFFERS_NUM; + const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const int next_buff = next_iter % FP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma 
unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const int dbias_offset_X = my_column; + const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % SHMEM_BUFFERS; + const int it_offset = iter * SHMEM_DIM; + + const int next_iter = iter + 1; + const int next_buff = next_iter % SHMEM_BUFFERS; + const int next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, + const int cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + 
stg_vec.store_to(thread_out_base); +} + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + cudaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); +} + +template +static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + 
const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); + + 
cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, + const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + checkCuDriverContext(stream); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + const auto &input_shape = input.data.shape; + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? 
reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(IType)); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, 
output->columnwise_data, rows, + cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + cast_mxfp8_2D_kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { + +using Empty = transformer_engine::Empty; + +__device__ inline float identity(float value, const Empty &) { return value; } + +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace { + +static bool is_full_tile_1D_tensor(const Tensor *const t) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + return isFullTile; +} + +bool dimensions_supported_by_TMA(const 
Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr int TMA_bytes = 16; + const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); + return cols % alignment_requirement == 0; +} + +} // namespace + +// Supported by the Arch >= 10.0 +template +void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DBIAS && !IS_DACT) { + if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype())) { + // Aligned AND FP8 + cast_fp8_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype())) { + // Aligned AND FP8 (+dAct) + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize(input, act_input, noop, output, dbias, + workspace, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + } +} + +// Supported by the Arch < 10.0 +template +void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + if (!is_delayed_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, 
output, stream); + } +} + +template +void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + fp8_quantize_arch_ge_100(input, act_input, noop, output, + dbias, workspace, stream); + } else { + // Supported by the Arch < 10.0 + fp8_quantize_arch_l_100(input, act_input, noop, output, + dbias, workspace, stream); + } +} + +namespace detail { + +template +void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = reinterpret_cast(grad); + activation_input_tensor = reinterpret_cast(input); + } else { + // forward = input is activation input + input_tensor = reinterpret_cast(input); + activation_input_tensor = nullptr; + } + auto output_tensor = reinterpret_cast(output); + auto dbias_tensor = reinterpret_cast(dbias); + auto workspace_tensor = reinterpret_cast(workspace); + const auto noop_tensor = noop != nullptr ? 
*(reinterpret_cast(noop)) : Tensor(); + + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + } else if (output_tensor->has_data()) { + fp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace detail +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index cc9a659b5b..8b6bb52397 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -81,6 +81,26 @@ int sm_count(int device_id) { return cache[device_id]; } +void stream_priority_range(int *low_priority, int *high_priority, int device_id) { + static std::vector> cache(num_devices()); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + int ori_dev = current_device(); + if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(device_id)); + int min_pri, max_pri; + NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri)); + if 
(device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev)); + cache[device_id] = std::make_pair(min_pri, max_pri); + }; + std::call_once(flags[device_id], init); + *low_priority = cache[device_id].first; + *high_priority = cache[device_id].second; +} + bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile time because the diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 33c2aea8d4..072eacd623 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -38,6 +38,16 @@ int sm_arch(int device_id = -1); */ int sm_count(int device_id = -1); +/* \brief Minimum and maximum stream priorities supported on device + * + * \param[in] device_id CUDA device (default is current device) + * + * \param[out] low_priority Lowest priority value on device. + * + * \param[out] high_priority Highest priority value on device. + */ +void stream_priority_range(int *low_priority, int *high_priority, int device_id = -1); + /* \brief CUDA Multicast support status for device * * \param[in] device_id CUDA device (default is current device) diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh new file mode 100644 index 0000000000..e529289640 --- /dev/null +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -0,0 +1,360 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_kernels.cuh + * \brief CUDA kernels to cast from MXFP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ + +#include +#include +#include +#include + +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +namespace transformer_engine { + +namespace dequantization { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t BUFFERS_NUM = 2; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, + const size_t scales_stride) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + 
constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. 
+#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + constexpr int iteration_zero = 0; + constexpr int buffer_zero = 0; + if (is_master_thread) { + const int chunk_stage_offset_Y = chunk_offset_Y; + const int chunk_stage_offset_X = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buffer_zero]), + reinterpret_cast(&tensor_map_input), chunk_stage_offset_X, + chunk_stage_offset_Y, &mbar[iteration_zero]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[iteration_zero], transaction_size); + + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[iteration_zero]); + } + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % BUFFERS_NUM; + const int next_iter = iter + 1; + if (next_iter < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_iter % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_buff]), + reinterpret_cast(&tensor_map_input), chunk_it_offset_x, + chunk_it_offset_y, &mbar[next_iter]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_iter], transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_iter]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + const int scale_offset_Y = + USE_ROWWISE_SCALING ? 
(scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) + : (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y); + + const int scale_offset_X = + USE_ROWWISE_SCALING + ? (scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE) + : (scales_colwise_chunk_offset_X + tid_colwise_X); + + const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; + const e8m0_t biased_exponent = scales_ptr[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out; + + const int shmem_offset_y = thread_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + } + out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]); + } else { +#pragma unroll + for (int i = 0; i < BUFFER_DIM_Y; ++i) { + const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + detail::DequantizeParam p; + p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} + +static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + bool use_rowwise_scaling = input.has_data(); + bool use_colwise_scaling = input.has_columnwise_data(); + checkCuDriverContext(stream); + + const auto &input_shape = input.data.shape; + NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions."); + + if (use_rowwise_scaling) { + NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + } + + if 
(use_colwise_scaling) { + NVTE_CHECK(input.has_columnwise_data(), "Cannot dequantize tensor without columnwise data."); + NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); + } + + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + + const size_t unpadded_scales_Y_rowwise = rows; + const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); + const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise); + const size_t unpadded_scales_X_colwise = cols; + + const size_t scales_Y_rowwise = + DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * + scale_tensor_alignment_Y_rowwise; + const size_t scales_X_rowwise = + DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * + scale_tensor_alignment_X_rowwise; + const size_t scales_Y_colwise = + DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * + scale_tensor_alignment_Y_colwise; + const size_t scales_X_colwise = + DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * + scale_tensor_alignment_X_colwise; + + const e8m0_t *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + : reinterpret_cast(input.columnwise_scale_inv.dptr); + + const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + + const SimpleTensor &input_data = use_rowwise_scaling ? 
input.data : input.columnwise_data; + + const dim3 block(THREADS_PER_CHUNK); + const dim3 grid(chunks_X, chunks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(OType)); + + dequantize_mxfp8_kernel + <<>>(tensor_map_input, tensor_map_output, scales_ptr, + rows, cols, scales_stride);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace dequantization + +namespace detail { + +void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if (is_tensor_scaling(input.scaling_mode)) { + dequantization::fp8_dequantize(input, output, stream); + } else if (is_mxfp_scaling(input.scaling_mode)) { + if (is_supported_by_CC_100()) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh new file mode 100644 index 0000000000..a22b930ecd --- /dev/null +++ b/transformer_engine/common/util/ptx.cuh @@ -0,0 +1,300 @@ 
+/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ptx.cuh + * \brief BW PTX + */ + +#ifndef TRANSFORMER_ENGINE_PTX_CUH_ +#define TRANSFORMER_ENGINE_PTX_CUH_ + +#include +#include + +namespace transformer_engine { +namespace ptx { + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval +__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), 
"r"(tx_count) + : "memory"); +} + +__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { + asm volatile("fence.mbarrier_init.release.cluster;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.shared::cta.global" + ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr), + "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, + const uint32_t offset_y, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 
1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), + "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, + const uint64_t *src_shmem, + const uint32_t size) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr), + "r"(src_shmem_ptr), "r"(size) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, + uint64_t *src_shmem) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"( + tensor_map_ptr), + "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) + : "memory"); +} + +__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile( + "{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { + uint32_t mbar_ptr = 
__cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +__device__ __forceinline__ void cp_async_bulk_wait_group() { + asm volatile("cp.async.bulk.wait_group 0;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} + +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { + asm volatile("cp.async.bulk.wait_group.read 1;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { + asm volatile("cp.async.bulk.wait_group.read 2;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { + asm volatile("cp.async.bulk.wait_group.read 4;"); +} + +// Proxy fence (bi-directional): +__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } +__device__ __forceinline__ void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +} // namespace ptx + +namespace { + +template +__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initialize barrier. 
All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), + num_bytes, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, + const size_t chunk_Y, const size_t num_bytes, + uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X, + chunk_Y, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, + const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, + const size_t chunk_X2, const size_t chunk_Y2, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), + chunk_X2, chunk_Y2, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx3( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, void *dst2, + const void *src2, const size_t chunk_X2, const size_t chunk_Y2, void *dst3, const void *src3, + const size_t chunk_X3, const size_t chunk_Y3, const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), + chunk_X2, chunk_Y2, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst3), + reinterpret_cast(src3), + chunk_X3, chunk_Y3, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PTX_CUH_ diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 97c5bee2b1..b3087d1fb7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -14,66 +14,98 @@ #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", 
NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - 
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", 
NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", 
transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + py::call_guard()) \ + .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def( \ + "get_stream_priority_range", \ + [](int device_id = -1) { \ + int low_pri, high_pri; \ + transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ + return std::make_pair(low_pri, high_pri); \ + }, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); #endif diff --git a/transformer_engine/common/util/system.h b/transformer_engine/common/util/system.h index e3a7164932..71c7ef3216 100644 --- a/transformer_engine/common/util/system.h +++ b/transformer_engine/common/util/system.h @@ -9,8 +9,6 @@ #include -#include "../common.h" - namespace transformer_engine { /*! 
\brief Get environment variable and convert to type diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index faf3ea0a61..420b9ed3bb 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -44,6 +44,13 @@ class VectorizedStorage { return *this; } inline __device__ ~VectorizedStorage() {} + + /* \brief Access to separate elements. */ + inline __device__ DType *separate() { return scratch_.separate; } + + inline __device__ const DType *separate() const { return scratch_.separate; } + + inline __device__ LType &aligned() { return scratch_.aligned; } }; // Returns const LType is DType is const @@ -167,9 +174,11 @@ constexpr int unary_kernel_threads = 512; template __launch_bounds__(unary_kernel_threads) __global__ - void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, - const size_t num_aligned_elements) { + void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, + const size_t N, const size_t num_aligned_elements) { + if (noop != nullptr && noop[0] == 1.0f) return; + VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; @@ -322,9 +331,9 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... 
ptrs) template -void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, fp32 *scale_inv, const size_t N, const Param params, - cudaStream_t stream) { +void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, + const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, + const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -337,16 +346,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, output, scale, amax, scale_inv, params, N, N); + input, noop, output, scale, amax, scale_inv, params, N, N); break; } } @@ -395,18 +404,19 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; VectorizedLoader loader0(input + id_y * n * 2, n); VectorizedLoader loader1(input + id_y * n * 2 
+ n, n); VectorizedStorer storer(output + id_y * n, n); - ComputeType max = 0; - ComputeType s = 1; - if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; - } - const int warp_id = threadIdx.x / THREADS_PER_WARP; loader0.load(id_x, n); loader1.load(id_x, n); @@ -423,21 +433,20 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.separate()[i] = static_cast(static_cast(temp)); } storer.store(id_x, n); - - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); - } + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -482,9 +491,17 @@ template __launch_bounds__(unary_kernel_threads) __global__ void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; @@ -507,12 +524,35 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * 
gate_in; ComputeType after_dgate = grad_val * Activation(gelu_in, p); + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(after_dgelu), max); + after_dgelu = after_dgelu * s; + max = fmaxf(fabsf(after_dgate), max); + after_dgate = after_dgate * s; + } + storer0.separate()[i] = static_cast(after_dgelu); storer1.separate()[i] = static_cast(after_dgate); } storer0.store(id_x, n); storer1.store(id_x, n); } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } } template void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input, - OutputType *output, const size_t m, const size_t n, - const Param &p, cudaStream_t stream) { + OutputType *output, const fp32 *scale, fp32 *amax, + fp32 *scale_inv, const size_t m, const size_t n, const Param &p, + cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -532,18 +573,19 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { case Alignment::SAME_ALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize dgated_act_kernel<1, 
true, ComputeType, Param, Activation, Dactivation> - <<>>(grad, input, output, m, n, p, n); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 6267baf19e..63ce369892 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -819,6 +819,21 @@ __device__ __forceinline__ float warp_reduce_max(const float m) { return tmp; } +__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero); + return val_tmp; +} + template __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) { __shared__ float staging[num_warps]; @@ -837,6 +852,29 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war return result; } +/** + * Max reduction in subwarps + * E.g., if nvec=4, each warp processes 128 elements (32 x 4), that covers four MXFP8 scaling factors. + * To compute an actual scaling factor for 32 consequentive elements, only 8 threads need to participate, + * thus splitting the warp into 4x smaller subwarps 8-thread width. + * 'Butterfly' reduction is used inside subwarps. 
+ */ +template +__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); + return val_tmp; +} + // Works only on positive values __device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) { atomicMax(reinterpret_cast(addr), __float_as_int(value)); @@ -857,6 +895,79 @@ __device__ __forceinline__ void reciprocal(float *value_inv, const float *value_inv = __frcp_rn(value); } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +using e8m0_t = uint8_t; + +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; + +template +struct Numeric_Traits; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 8; + static constexpr double maxNorm = 448; +}; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 15; + static constexpr double maxNorm = 57344; +}; + +template +struct Quantized_Limits { + static constexpr int max_unbiased_exponent = Numeric_Traits::maxUnbiasedExponent; + static constexpr float max_norm = Numeric_Traits::maxNorm; + static constexpr float max_norm_rcp = 1.0 / max_norm; + static constexpr float emax = 1 << max_unbiased_exponent; + static constexpr float emax_rcp = 1.0 / emax; +}; + +__device__ __forceinline__ e8m0_t float_to_e8m0(float 
val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 41a6846a7c..a5457fa032 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -6,6 +6,7 @@ #include "transformer_engine/activation.h" #include "extensions.h" +#include "transformer_engine/cast.h" #include "transformer_engine/transpose.h" #include "xla/ffi/api/c_api.h" @@ -332,18 +333,27 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_ auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto 
output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias_dgelu chooses its internal impl based + // on what pointers are allocated, e.g. whether to output with + // column-wise data. However, we don't have access to any allocated + // buffers in this function. We pass a dummy pointer as a + // workaround. + int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto dact_input_tensor = + TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); + auto output_tensor = TensorWrapper(); + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); TensorWrapper dummy_workspace; // For now, all dbias_dact(-s) have the same workspace size - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); @@ -384,37 +394,32 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + 
output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); switch (act_enum) { case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - 
dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -468,37 +473,32 @@ Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto act_type = static_cast(act_enum); switch (act_type) { case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), 
output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -555,29 +555,29 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); switch (act_enum) { case NVTE_Activation_Type::GEGLU: nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), 
output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::REGLU: nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::QGEGLU: nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::SREGLU: nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -622,30 +622,30 @@ Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto act_type = static_cast(act_enum); switch (act_type) { case NVTE_Activation_Type::GEGLU: nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::REGLU: nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); 
break; case NVTE_Activation_Type::QGEGLU: nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::SREGLU: nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 569dfd3baa..71d1456287 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -25,7 +25,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op auto input_tensor = TensorWrapper(input, shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv); - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -48,7 +48,7 @@ Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a auto input_tensor = TensorWrapper(input, shape, in_dtype); auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv); - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); return ffi_with_cuda_error_check(); } @@ -76,7 +76,7 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), 
stream); + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -96,7 +96,7 @@ Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 516930c529..af347f45b2 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/transpose.h" #include "extensions.h" +#include "transformer_engine/cast.h" #include "xla/ffi/api/ffi.h" namespace transformer_engine { @@ -89,13 +90,12 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size auto input_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto input_cast_tensor = + auto output_tensor = TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, desc.out_dtype, - amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(input_cast_trans, desc.out_dtype, input_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); - nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), - stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -131,11 
+131,11 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); + + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); - nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - stream); return ffi_with_cuda_error_check(); } @@ -159,15 +159,22 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias chooses its internal impl based on what + // pointers are allocated, e.g. whether to output with column-wise + // data. However, we don't have access to any allocated buffers in + // this function. We pass a dummy pointer as a workaround. 
+ int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); TensorWrapper dummy_workspace; - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); @@ -203,14 +210,14 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + workspace.data(), stream); } Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -253,13 +260,13 @@ Error_Type DBiasCastTransposeFFI(cudaStream_t stream, 
Buffer_Type input_buf, Buf auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + workspace_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index e7ee350b46..f2dbd3b131 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -354,11 +354,6 @@ def fp8_autocast( assert ( fp8_recipe.scaling_factor_compute_algo is None ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." - assert fp8_recipe.override_linear_precision == ( - False, - False, - False, - ), "DelayedScaling override_linear_precision isn't supported by TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." 
if mesh_resource is None: diff --git a/transformer_engine/paddle/MANIFEST.in b/transformer_engine/paddle/MANIFEST.in deleted file mode 100644 index 0c814f95da..0000000000 --- a/transformer_engine/paddle/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include build_tools *.* -recursive-include common_headers *.* -recursive-include csrc *.* diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py deleted file mode 100644 index 583c4a7a7a..0000000000 --- a/transformer_engine/paddle/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Transformer Engine bindings for Paddle""" - -# pylint: disable=wrong-import-position,wrong-import-order - -import logging -from importlib.metadata import version - -from transformer_engine.common import is_package_installed - - -def _load_library(): - """Load shared library with Transformer Engine C extensions""" - module_name = "transformer_engine_paddle" - - if is_package_installed(module_name): - assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." - assert is_package_installed( - "transformer_engine_cu12" - ), "Could not find `transformer-engine-cu12`." - assert ( - version(module_name) - == version("transformer-engine") - == version("transformer-engine-cu12") - ), ( - "TransformerEngine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" - " transformer-engine[paddle]==VERSION'" - ) - - if is_package_installed("transformer-engine-cu12"): - if not is_package_installed(module_name): - logging.info( - "Could not find package %s. 
Install transformer-engine using 'pip" - " install transformer-engine[paddle]==VERSION'", - module_name, - ) - - from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import - - -_load_library() -from .fp8 import fp8_autocast -from .layer import ( - Linear, - LayerNorm, - LayerNormLinear, - LayerNormMLP, - FusedScaleMaskSoftmax, - DotProductAttention, - MultiHeadAttention, - TransformerLayer, - RotaryPositionEmbedding, -) -from .recompute import recompute diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py deleted file mode 100644 index dee8a70c38..0000000000 --- a/transformer_engine/paddle/constants.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Constants""" - -from enum import Enum - -import paddle - -from transformer_engine import transformer_engine_paddle as tex - - -class FP8FwdTensors(Enum): - """Used as named indices on the `scale`, `scale_inv`, - and `amax` tensors in the `FP8TensorMeta` class.""" - - GEMM1_INPUT = 0 - GEMM1_WEIGHT = 1 - GEMM1_OUTPUT = 2 - GEMM2_INPUT = 3 - GEMM2_WEIGHT = 4 - GEMM2_OUTPUT = 5 - - -class FP8BwdTensors(Enum): - """Used as named indices on the `scale`, `scale_inv`, - and `amax` tensors in the `FP8TensorMeta` class.""" - - GRAD_OUTPUT1 = 0 - GRAD_INPUT1 = 1 - GRAD_OUTPUT2 = 2 - GRAD_INPUT2 = 3 - - -""" -Map from paddle dtype to TE dtype -""" -TE_DType = { - paddle.uint8: tex.DType.kByte, - paddle.int32: tex.DType.kInt32, - paddle.float32: tex.DType.kFloat32, - paddle.float16: tex.DType.kFloat16, - paddle.bfloat16: tex.DType.kBFloat16, -} - -AttnMaskTypes = ("causal", "padding", "no_mask") - -AttnTypes = ("self", "cross") - -LayerTypes = ("encoder", "decoder") - -GemmParallelModes = ("row", "column", None) - -dist_group_type = paddle.distributed.collective.Group - -RecomputeFunctionNames = ("unpack", "backward") - -AttnBiasType = { - "no_bias": 
tex.NVTE_Bias_Type.NVTE_NO_BIAS, - "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS, - "post_scale_bias": tex.NVTE_Bias_Type.NVTE_POST_SCALE_BIAS, -} - -AttnMaskType = { - "no_mask": tex.NVTE_Mask_Type.NVTE_NO_MASK, - "padding": tex.NVTE_Mask_Type.NVTE_PADDING_MASK, - "causal": tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, -} - -FusedAttnBackend = { - "F16_max512_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, - "F16_arbitrary_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - "No_Backend": tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, -} diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py deleted file mode 100644 index 293c62a2fd..0000000000 --- a/transformer_engine/paddle/cpp_extensions.py +++ /dev/null @@ -1,1199 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""TE FP8 extensions and GEMMs""" - -import math -from typing import Optional, Tuple, Union -import paddle -import paddle.nn.functional as F -from transformer_engine import transformer_engine_paddle as tex -from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors -from .fp8 import FP8TensorMeta, get_global_fp8_state - -BACKEND_F16m512_THREADS_PER_CTA = 128 -BACKEND_F16arb_ELTS_PER_THREADS = 16 - - -def gemm( - A: paddle.Tensor, - B: paddle.Tensor, - dtype: paddle.dtype, - workspace: paddle.Tensor, - gelu: bool = False, - gelu_input: Optional[paddle.Tensor] = None, - grad: bool = False, - accumulate: bool = False, - layout: str = "TN", - out: Optional[paddle.Tensor] = None, - out_dtype: Optional[paddle.dtype] = None, - bias: Optional[paddle.Tensor] = None, - use_bias: bool = False, -) -> Tuple[Union[paddle.Tensor, None], ...]: - """Non FP8 GEMM.""" - - assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." 
- transa = layout[0] == "T" - transb = layout[1] == "T" - - if out is None: - if accumulate: - out = paddle.zeros( - shape=[ - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - ], - dtype=out_dtype if out_dtype is not None else dtype, - ) - else: - out = paddle.empty( - shape=[ - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - ], - dtype=out_dtype if out_dtype is not None else dtype, - ) - - if gelu and not grad: - gelu_input = paddle.empty_like(out, dtype=dtype) - elif not gelu: - gelu_input = None - - if grad and use_bias: - grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype) - else: - grad_bias = None - - bias = bias if use_bias else None - - assert ( - A.dtype == dtype and B.dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out.dtype] - if use_bias: - bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] - else: - bias_dtype = output_dtype - - tex.te_gemm( - A, - None, - B, - None, - grad_bias if grad else bias, - out, - None, # out_scale - None, # out_amax - gelu_input, - workspace, - 0, # A_index - 0, # B_index - 0, # D_index - int(input_dtype), - int(input_dtype), - int(output_dtype), - int(bias_dtype), - transa, - transb, - grad, - workspace.shape[0], - accumulate, - False, # use_split_accumulator - 0, # math_sm_count - ) - - return out, grad_bias, gelu_input - - -def fp8_gemm( - A: paddle.Tensor, - A_scale_inv: paddle.Tensor, - A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - A_dtype: tex.DType, - B: paddle.Tensor, - B_scale_inv: paddle.Tensor, - B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - B_dtype: tex.DType, - out_dtype: paddle.dtype, - workspace: paddle.Tensor, - gelu: bool = False, - accumulate: bool = False, - out: Optional[paddle.Tensor] = None, - out_index=None, - fp8_meta_tensor: FP8TensorMeta = None, - bias: Optional[paddle.Tensor] = None, - 
use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> paddle.Tensor: - """TN layout GEMM with fp8 inputs.""" - - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_index is not None - - if out is None: - if accumulate: - out = paddle.zeros( - shape=[ - B.shape[0], - A.shape[0], - ], - dtype=out_dtype, - ) - else: - out = paddle.empty( - shape=[ - B.shape[0], - A.shape[0], - ], - dtype=out_dtype, - ) - - # Use bfloat16 as default bias_dtype - bias_dtype = paddle.bfloat16 if bias is None else bias.dtype - if gelu: - gelu_input = paddle.empty_like(out, dtype=bias_dtype) - else: - gelu_input = None - bias_dtype = TE_DType[bias_dtype] - - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype - - tex.te_gemm( - A, - A_scale_inv, - B, - B_scale_inv, - bias if use_bias else None, - out, - None if out_index is None else fp8_meta_tensor.scale, - None if out_index is None else fp8_meta_tensor.amax_history, - gelu_input, # this is pre_gelu_out - workspace, - A_fp8_tensor.value, - B_fp8_tensor.value, - 0 if out_index is None else out_index, - int(A_dtype), - int(B_dtype), - int(out_dtype), - int(bias_dtype), - True, # transa - False, # transb - False, # grad - workspace.shape[0], - accumulate, - use_split_accumulator, - 0, # math_sm_count - ) - - return out, gelu_input - - -def cast_to_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - out: Optional[paddle.Tensor] = None, -) -> paddle.Tensor: - """Cast input to FP8""" - if out is None: - out = paddle.empty( - shape=inp.shape, - dtype=paddle.uint8, - ) - else: - assert out.shape == inp.shape, "Output shape does not match input shape." - assert out.dtype == paddle.uint8, "Output should be of uint8 dtype." 
- - tex.cast_to_fp8( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - return out - - -def cast_from_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - itype: tex.DType, - otype: tex.DType, -) -> paddle.Tensor: - """Cast input from FP8""" - return tex.cast_from_fp8( - inp, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(itype), - int(otype), - ) - - -def transpose( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Transpose input""" - return tex.te_transpose( - inp, - int(otype), - ) - - -def cast_transpose( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - cast_out: Optional[paddle.Tensor] = None, - transpose_out: Optional[paddle.Tensor] = None, -) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]: - """Cast + Transpose with FP8 output""" - if cast_out is None: - cast_out = paddle.empty( - shape=inp.shape, - dtype=paddle.uint8, - ) - else: - assert cast_out.shape == inp.shape, "cast_out shape does not match input shape." - assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype." - - if transpose_out is None: - transpose_out = paddle.empty( - shape=[inp.shape[1], inp.shape[0]], - dtype=paddle.uint8, - ) - else: - assert transpose_out.shape == [ - inp.shape[1], - inp.shape[0], - ], "Transposed output shape does not match input shape." - assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype." 
- - tex.te_cast_transpose( - inp, - fp8_meta_tensor.scale, - cast_out, - transpose_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return cast_out, transpose_out - - -def cast_transpose_bgrad( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]: - """Fused Cast + Transpose + Bias Grad""" - grad_bias, cast_out, transpose_out, _, _ = tex.te_cast_transpose_bgrad( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return grad_bias, cast_out, transpose_out - - -def te_gelu( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Non FP8 GELU""" - return tex.te_gelu( - inp, - int(otype), - ) - - -def gelu_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> paddle.Tensor: - """GELU + FP8 cast""" - out, _, _ = tex.te_gelu_fp8( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return out - - -def swiglu( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Non FP8 SWIGLU""" - return tex.te_swiglu( - inp, - int(otype), - ) - - -def swiglu_pd( - inp: paddle.Tensor, -) -> paddle.Tensor: - """Native SWIGLU""" - gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1) - out = F.silu(gate_out) * up_out - return out - - -def swiglu_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> paddle.Tensor: - """SWIGLU + FP8 cast""" - out, _, _ = tex.te_swiglu_fp8( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return out - - -def dswiglu( - 
grad_output: paddle.Tensor, - swiglu_input: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """dSWIGLU""" - return tex.te_dswiglu( - grad_output, - swiglu_input, - int(otype), - ) - - -def dgelu_cast_transpose_bgrad_fp8( - grad_output: paddle.Tensor, - gelu_input: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """ - Fused dgelu + cast / transpose / reduce the result of - the GELU backward along the first dimension - """ - cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu( - grad_output, - gelu_input, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return cast_dgelu, transpose_dgelu, dbias - - -def layernorm_fwd_fp8( - inp: paddle.Tensor, - weight: paddle.Tensor, - bias: paddle.Tensor, - eps: float, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """LayerNorm with FP8 output""" - out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8( - inp, - weight, - bias, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - eps, - fp8_tensor.value, - int(otype), - sm_margin, - zero_centered_gamma, - ) - return out, mu, rsigma - - -def layernorm_fwd( - inp: paddle.Tensor, - weight: paddle.Tensor, - bias: paddle.Tensor, - eps: float, - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 LayerNorm forward""" - return tex.te_layernorm_fwd(inp, weight, bias, eps, int(otype), sm_margin, zero_centered_gamma) - - -def layernorm_bwd( - dz: paddle.Tensor, - x: paddle.Tensor, - mu: paddle.Tensor, - rsigma: paddle.Tensor, - gamma: paddle.Tensor, - 
sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 LayerNorm backward""" - return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) - - -def rmsnorm_fwd( - inp: paddle.Tensor, - weight: paddle.Tensor, - eps: float, - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 RMSNorm forward""" - return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma) - - -def rmsnorm_fwd_fp8( - inp: paddle.Tensor, - weight: paddle.Tensor, - eps: float, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """RMSNorm with FP8 output""" - out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8( - inp, - weight, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - eps, - fp8_tensor.value, - int(otype), - sm_margin, - zero_centered_gamma, - ) - return out, rsigma - - -def rmsnorm_bwd( - dz: paddle.Tensor, - x: paddle.Tensor, - rsigma: paddle.Tensor, - gamma: paddle.Tensor, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 RMSNorm backward""" - return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) - - -def mask_to_cu_seqlens( - mask: paddle.Tensor, - need_kv: bool = False, -) -> paddle.Tensor: - """Convert mask to cu_seqlens""" - # mask shape: [b, 1, s_q, s_kv] - if get_global_fp8_state().is_cudagraph_enabled(): - raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.") - q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3] - q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32) - q_cu_seqlens[0] = 0 - kv_cu_seqlens = None - if 
need_kv: - kv_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32) - kv_cu_seqlens[0] = 0 - tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv) - return q_cu_seqlens, kv_cu_seqlens - - -def fused_attn_fwd_qkvpacked( - qkv: paddle.Tensor, - cu_seqlens: paddle.Tensor, - is_training: bool, - max_seqlen: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for packed QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - - b = cu_seqlens.shape[0] - 1 - total_seqs = qkv.shape[0] * qkv.shape[1] - h = qkv.shape[3] - d = qkv.shape[4] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen, - max_seqlen, - ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
- - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd_qkvpacked( - qkv, - cu_seqlens, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - total_seqs, - max_seqlen, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - return out, softmax_aux, rng_state - - -def fused_attn_bwd_qkvpacked( - qkv: paddle.Tensor, - cu_seqlens: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bs3hd", - bias_type: str = 
"no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - - b = cu_seqlens.shape[0] - 1 - total_seqs = qkv.shape[0] * qkv.shape[1] - h = qkv.shape[3] - d = qkv.shape[4] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) - else: - dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype) - - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - else: - dbias = None - # execute kernel - dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked( - qkv, - cu_seqlens, - o, - d_o, - softmax_aux, - dqkv, - dbias, - rng_state, - b, - h, - d, - total_seqs, - max_seqlen, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - - return dqkv, dbias - - -def fused_attn_fwd_kvpacked( - q: paddle.Tensor, - kv: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bs2hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for 
packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - - b = cu_seqlens_q.shape[0] - 1 - total_seqs_q = q.shape[0] * q.shape[1] - total_seqs_kv = kv.shape[0] * kv.shape[1] - h = q.shape[2] - d = q.shape[3] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen_q, - max_seqlen_kv, - ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = 
paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - total_seqs_q, - total_seqs_kv, - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - - return out, softmax_aux, rng_state - - -def fused_attn_bwd_kvpacked( - q: paddle.Tensor, - kv: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bs2hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - - b = cu_seqlens_q.shape[0] - 1 - total_seqs_q = q.shape[0] * q.shape[1] - total_seqs_kv = kv.shape[0] * kv.shape[1] - h = q.shape[2] - d = q.shape[3] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
- - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) - dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype) - else: - dq = paddle.empty(shape=q.shape, dtype=q.dtype) - dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype) - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = None - # execute kernel - tex.te_fused_attn_bwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - o, - d_o, - softmax_aux, - dq, - dkv, - dbias, - rng_state, - b, - h, - d, - total_seqs_q, - total_seqs_kv, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - return dq, dkv, dbias - - -def fused_attn_fwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bshd_bshd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for unpacked QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert ( - qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." 
- b = cu_seqlens_q.shape[0] - 1 - - h = q.shape[-2] - d = q.shape[-1] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen_q, - max_seqlen_kv, - ], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." - assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - Bias, - out, - softmax_aux, - 
rng_state, - b, - h, - d, - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - return out, softmax_aux, rng_state - - -def fused_attn_bwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bshd_bshd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert ( - qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." - - b = cu_seqlens_q.shape[0] - 1 - h = q.shape[-2] - d = q.shape[-1] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
- - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) - dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype) - dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype) - else: - dq = paddle.empty(shape=q.shape, dtype=q.dtype) - dk = paddle.empty(shape=k.shape, dtype=k.dtype) - dv = paddle.empty(shape=v.shape, dtype=v.dtype) - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = None - # execute kernel - tex.te_fused_attn_bwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - o, - d_o, - softmax_aux, - dq, - dk, - dv, - dbias, - rng_state, - b, - h, - d, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - return dq, dk, dv, dbias - - -def scaled_softmax_forward( - inp: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled softmax forward""" - return tex.te_scaled_softmax_forward(inp, scale_factor) - - -def scaled_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled softmax backward""" - tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad - - -def scaled_masked_softmax_forward( - inp: paddle.Tensor, - mask: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled masked softmax forward""" - - return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor) - - -def scaled_masked_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled masked softmax backward""" - tex.te_scaled_softmax_backward(out_grad, 
softmax_results, scale_factor) - return out_grad - - -def scaled_upper_triang_masked_softmax_forward( - inp: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled upper triang masked softmax forward""" - return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor) - - -def scaled_upper_triang_masked_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled upper triang masked softmax backward""" - tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad diff --git a/transformer_engine/paddle/csrc/common.cpp b/transformer_engine/paddle/csrc/common.cpp deleted file mode 100644 index d65fbb2b50..0000000000 --- a/transformer_engine/paddle/csrc/common.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "common.h" - -namespace transformer_engine { -namespace paddle_ext { - -TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, - const DType type) { - return TensorWrapper(const_cast(data_ptr), shape, type); -} - -TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) { - return TensorWrapper(data_ptr, shape, type); -} - -TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, - void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) { - return TensorWrapper(data_ptr, shape, type, reinterpret_cast(amax_ptr), - reinterpret_cast(scale_ptr), - reinterpret_cast(scale_inv_ptr)); -} - -TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) { // NOLINT - return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype())); -} - -TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) { - return MakeNvteTensor(const_cast(tensor.data()), GetShapeArray(tensor), - Paddle2NvteDType(tensor.dtype())); -} - -paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, - bool init_to_zeros) { - auto size = shape.ndim; - if (size == 2 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 2) { - return paddle::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 1 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } else if (size == 1) { - return paddle::empty({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } - NVTE_CHECK(false, "Should never reach here! 
func: AllocateSpace"); -} - -// MHA utils -// convert QKV layout to enum -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) { - static const std::unordered_map layout_map = { - {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD}, - {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D}, - {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD}, - {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D}, - {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD}, - {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD}, - {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D}, - {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD}, - {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D}, - {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD}, - {"t3hd", NVTE_QKV_Layout::NVTE_T3HD}, - {"th3d", NVTE_QKV_Layout::NVTE_TH3D}, - {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD}, - {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D}, - {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD}, - }; - - auto it = layout_map.find(qkv_layout); - if (it != layout_map.end()) { - return it->second; - } else { - NVTE_ERROR("Invalid QKV layout string: " + qkv_layout); - } -} - -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h deleted file mode 100644 index 83737c0d21..0000000000 --- a/transformer_engine/paddle/csrc/common.h +++ /dev/null @@ -1,185 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "common/util/logging.h" -#include "paddle/extension.h" -#include "paddle/phi/backends/all_context.h" - -namespace transformer_engine { -namespace paddle_ext { -// Paddle Tensor Utils -template -inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) { - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); -} - -template -inline void *GetDataPtr(paddle::Tensor &x, int64_t index) { // NOLINT - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); -} - -template -inline const void *GetOptionalDataPtr(const paddle::optional &x, int64_t index) { - return x ? GetDataPtr(*x, index) : nullptr; -} - -template -inline void *GetOptionalDataPtr(paddle::optional &x, int64_t index) { // NOLINT - return x ? GetDataPtr(*x, index) : nullptr; -} - -inline const void *GetOptionalDataPtr(const paddle::optional &x) { - return x ? x->data() : nullptr; -} - -inline void *GetOptionalDataPtr(paddle::optional &x) { // NOLINT - return x ? 
x->data() : nullptr; -} - -inline std::vector GetShapeArray(const paddle::Tensor &x) { - std::vector shapes; - for (auto dim : x.shape()) { - shapes.push_back(static_cast(dim)); - } - return shapes; -} - -inline std::vector GetShapeArray(const paddle::optional &x) { - if (x) return GetShapeArray(x.get()); - return {0}; -} - -paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, - bool init_to_zeros = 0); - -// DType Utils -inline paddle::DataType Nvte2PaddleDType(DType t) { - switch (t) { - case DType::kInt32: - case DType::kFloat32: - return paddle::DataType::FLOAT32; - case DType::kFloat16: - return paddle::DataType::FLOAT16; - case DType::kBFloat16: - return paddle::DataType::BFLOAT16; - case DType::kByte: - case DType::kFloat8E4M3: - case DType::kFloat8E5M2: - return paddle::DataType::UINT8; - default: - NVTE_ERROR("Invalid type"); - } -} - -inline DType Paddle2NvteDType(paddle::DataType t) { - switch (t) { - case paddle::DataType::FLOAT16: - return DType::kFloat16; - case paddle::DataType::FLOAT32: - return DType::kFloat32; - case paddle::DataType::BFLOAT16: - return DType::kBFloat16; - case paddle::DataType::BOOL: - return DType::kByte; - case paddle::DataType::UINT8: - return DType::kByte; - case paddle::DataType::INT32: - return DType::kInt32; - case paddle::DataType::INT64: - return DType::kInt64; - default: - NVTE_ERROR("Invalid type"); - } -} - -inline DType Int2NvteDType(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } -} - -// get the fused attention backend -inline NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t 
head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, head_dim, -1, -1); - return fused_attention_backend; -} - -// CUDA Utils -class cudaDevicePropertiesManager { - public: - static cudaDevicePropertiesManager &Instance() { - static thread_local cudaDevicePropertiesManager instance; - return instance; - } - - int GetMultiProcessorCount() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.multiProcessorCount; - } - - int GetMajor() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.major; - } - - private: - bool prop_queried_ = false; - cudaDeviceProp prop_; -}; - -// NVTE Tensor Utils -TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, - const DType type); -TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type); -TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, - void *amax_ptr, void *scale_ptr, void *scale_inv_ptr); -TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT -TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor); - -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout); - -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu deleted file mode 100644 index 460f4575e6..0000000000 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ /dev/null @@ -1,1776 +0,0 @@ -/************************************************************************* - * Copyright (c) 
2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "common.h" -#include "common/common.h" -#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" - -namespace transformer_engine { -namespace paddle_ext { - -// convert bias type to enum -NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { - if (bias_type == "no_bias") { - return NVTE_Bias_Type::NVTE_NO_BIAS; - } else if (bias_type == "pre_scale_bias") { - return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; - } else if (bias_type == "post_scale_bias") { - return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; - } else { - NVTE_ERROR("Invalid bias type. \n"); - } -} - -// convert attn mask type to enum -NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { - if (mask_type == "padding") { - return NVTE_Mask_Type::NVTE_PADDING_MASK; - } else if (mask_type == "causal") { - return NVTE_Mask_Type::NVTE_CAUSAL_MASK; - } else if (mask_type == "no_mask") { - return NVTE_Mask_Type::NVTE_NO_MASK; - } else { - NVTE_ERROR("Invalid attention mask type. 
\n"); - } -} - -void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &output, // NOLINT - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), shape, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream()); -} - -std::vector cast_from_fp8(const paddle::Tensor &input, - const paddle::Tensor &scale_inv, int64_t index, - int64_t itype, int64_t otype) { - auto shape = GetShapeArray(input); - - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype))); - auto input_cu = - MakeNvteTensor(const_cast(input.data()), shape, Int2NvteDType(itype), nullptr, - nullptr, const_cast(GetDataPtr(scale_inv, index))); - auto output_cu = MakeNvteTensor(output); - - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_transpose(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(const_cast(input.data()), {M, N}, Int2NvteDType(otype)); - auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype)); - - nvte_transpose(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &output_cast, // NOLINT - paddle::Tensor &output_transpose, // NOLINT - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t 
otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto input_cu = MakeNvteTensor(input); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - input.stream()); -} - -std::vector te_cast_transpose_bgrad(const paddle::Tensor &grad_output, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - auto grad_output_cast = - paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - auto grad_output_transpose = - paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - - auto input_cu = MakeNvteTensor(grad_output); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto output_transpose_cu = - MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data, - scale_data, 
scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - transformer_engine::TensorWrapper workspace; - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - -void te_gemm(const paddle::Tensor &A, const paddle::optional &A_scale_inverse, - const paddle::Tensor &B, const paddle::optional &B_scale_inverse, - const paddle::optional &bias, paddle::Tensor &D, // NOLINT - paddle::optional &D_scale, // NOLINT - paddle::optional &D_amax, // NOLINT - paddle::optional &pre_gelu_out, paddle::Tensor &workspace, // NOLINT - int64_t A_index, int64_t B_index, int64_t D_index, int64_t A_type, int64_t B_type, - int64_t D_type, int64_t bias_type, bool transa, bool transb, bool grad, - int64_t workspace_size, bool accumulate, bool use_split_accumulator, - int64_t math_sm_count) { - auto te_A = MakeNvteTensor( - const_cast(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(A_scale_inverse, A_index))); - auto te_B = MakeNvteTensor( - const_cast(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(B_scale_inverse, B_index))); - auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type), - GetOptionalDataPtr(D_amax, D_index), - GetOptionalDataPtr(D_scale, D_index), nullptr); - - auto te_bias = MakeNvteTensor(const_cast(GetOptionalDataPtr(bias)), GetShapeArray(bias), - Int2NvteDType(bias_type)); - - DType gelu_dtype = pre_gelu_out ? 
Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type); - auto te_pre_gelu_out = - MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype); - auto te_workspace = - MakeNvteTensor(workspace.data(), {static_cast(workspace_size)}, DType::kByte); - - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, A.stream()); -} - -std::vector te_gelu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_gelu(const paddle::Tensor &input, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype)); - - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_swiglu(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), 
Int2NvteDType(otype)); - - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input, - int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype())); - auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype())); - auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype())); - - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output, - const paddle::Tensor &gelu_input, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 
2, "Expect the grad_output to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - // DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - - auto dgelu = paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place()); - - auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(DType::kByte), grad_output.place()); - - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - - TensorWrapper workspace; - - auto gelu_input_cu = MakeNvteTensor(gelu_input); - auto input_cu = MakeNvteTensor(grad_output); - auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); - - return {dgelu, dgelu_transpose, grad_bias}; -} - -std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &bias, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor 
&scale_inv, // NOLINT - float eps, int64_t index, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - return {ln_out, mu, rsigma}; -} - -std::vector te_layernorm_fwd(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &bias, float eps, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - auto shape = 
GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - return {ln_out, mu, rsigma}; -} - -std::vector te_layernorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, - const paddle::Tensor &mu, const paddle::Tensor &rsigma, - const paddle::Tensor &gamma, int64_t sm_margin, - bool zero_centered_gamma) { - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace; - - auto 
dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - auto dbeta_cu = MakeNvteTensor(dbeta); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), - num_sm - sm_margin, zero_centered_gamma, dz.stream()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. - nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), - num_sm - sm_margin, zero_centered_gamma, dz.stream()); - - return {dx, dgamma, dbeta}; -} - -std::vector te_rmsnorm_fwd(const paddle::Tensor &input, - const paddle::Tensor &weight, float eps, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - 
- auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - return {ln_out, rsigma}; -} - -std::vector te_rmsnorm_fwd_fp8(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - float eps, int64_t index, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - 
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - return {ln_out, rsigma}; -} - -std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, - const paddle::Tensor &rsigma, - const paddle::Tensor &gamma, int64_t sm_margin, - bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, - dz.stream()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. 
- nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, - dz.stream()); - - return {dx, dgamma}; -} - -__global__ void set_rng_state( - [[maybe_unused]] unsigned int - identifier, // This is used to relate kernel to cudaGraph nodes please refer to https://github.com/PaddlePaddle/Paddle/pull/60516 - std::pair seed_offset, int64_t *rng_state_ptr) { - rng_state_ptr[0] = static_cast(seed_offset.first); - rng_state_ptr[1] = static_cast(seed_offset.second); -} - -void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_per_thread, - paddle::Tensor &rng_state) { - // extract random number generator seed and offset - const phi::DeviceContext *dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - - phi::Generator *gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - int64_t *rng_state_p = static_cast(rng_state.data()); -#if PADDLE_VERSION > 261 - auto state_index = gen_cuda->GetStateIndex(); - - auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { - // ensure the generator use correct state index - gen_cuda->SetStateIndex(state_index); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - params.As>(1) = seed_offset; - }; - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&set_rng_state); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); -#endif -} - -void 
te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, - const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs, - int64_t max_seqlen, bool is_training, float attn_scale, - float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - UpdateRandomGenerator(QKV.place(), QKV.stream(), rng_elts_per_thread, rng_state); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, 
qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -// fused attention BWD with packed QKV -void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, - const paddle::Tensor &O, const paddle::Tensor &dO, - const paddle::Tensor &softmax_aux, - paddle::Tensor &dQKV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs, - int64_t max_seqlen, float attn_scale, float p_dropout, - const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, int64_t qkv_type, - bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQKV = MakeNvteTensor(dQKV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen), static_cast(max_seqlen)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens; - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - deterministic, workspace.data(), QKV.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - 
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - deterministic, workspace.data(), QKV.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_fwd_kvpacked( - const paddle::Tensor &Q, const paddle::Tensor &KV, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_kv, const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, int64_t total_seqs_kv, - int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, float attn_scale, - float p_dropout, const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor( - Q.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - te_KV = MakeNvteTensor( - KV.data(), - {static_cast(total_seqs_kv), 2, static_cast(h), static_cast(d)}, - qkv_dtype); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor( - O.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - UpdateRandomGenerator(Q.place(), Q.stream(), rng_elts_per_thread, rng_state); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), 
te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -// fused attention BWD with packed KV -void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_kv, const paddle::Tensor &O, - const paddle::Tensor &dO, const paddle::Tensor &softmax_aux, - paddle::Tensor &dQ, // NOLINT - paddle::Tensor &dKV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, - int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, - float attn_scale, float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type, bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_KV = MakeNvteTensor(KV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dKV = MakeNvteTensor(dKV); - } else 
{ - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, deterministic, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), 
Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, deterministic, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, - const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, - bool is_training, float attn_scale, float p_dropout, - const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, const int64_t qkv_type, - int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // extract random number generator seed and offset - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - auto stream = Q.stream(); - auto rng_state_p = static_cast(rng_state.data()); -#if PADDLE_VERSION > 261 - auto state_index = gen_cuda->GetStateIndex(); - auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { - // ensure the generator use correct state index - gen_cuda->SetStateIndex(state_index); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - params.As>(1) = seed_offset; - }; - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&set_rng_state); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); -#endif - - auto te_rng_state = MakeNvteTensor(rng_state); - - // 
create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, - p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, - p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, - const paddle::Tensor &O, const paddle::Tensor &dO, - const paddle::Tensor &softmax_aux, - paddle::Tensor &dQ, // NOLINT - paddle::Tensor &dK, // NOLINT - 
paddle::Tensor &dV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, - float attn_scale, float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type, bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dK = MakeNvteTensor(dK); - te_dV = MakeNvteTensor(dV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = 
MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -std::vector te_scaled_softmax_forward(const paddle::Tensor &input, - float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, - input.stream()); - - return {softmax_results}; -} - -void te_scaled_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - 
NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); -} - -std::vector te_scaled_masked_softmax_forward(const paddle::Tensor &input, - const paddle::Tensor &mask, - float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int pad_batches = mask.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - NVTE_CHECK(pad_batches == 1 || pad_batches == batches); - NVTE_CHECK(mask.shape()[1] == 1); - NVTE_CHECK(mask.shape()[2] == query_seq_len); - NVTE_CHECK(mask.shape()[3] == key_seq_len); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto mask_cu = MakeNvteTensor(mask); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; -} - -void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - 
NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); -} - -std::vector te_scaled_upper_triang_masked_softmax_forward( - const paddle::Tensor &input, float scale_factor) { - NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int attn_batches = input.shape()[0]; - const int seq_len = input.shape()[1]; - NVTE_CHECK(seq_len <= 2048); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; -} - -void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, - float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - 
NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_upper_triang_masked_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, - softmax_results.stream()); -} - -__global__ void UpdateFP8MetaKernel( - [[maybe_unused]] unsigned int - identifier, // This is used to relate kernel to cudaGraph nodes please refer to https://github.com/PaddlePaddle/Paddle/pull/60516 - const float *amax, const float *rolled_amax_history, const bool *non_weight_mask, - float *amax_history, float *scale, float *scale_inv, bool update_weight_scale_inv, float margin, - float fp8_max, size_t history_numel, size_t amax_numel) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx >= history_numel) { - return; - } - - amax_history[idx] = rolled_amax_history[idx]; - - if (idx < amax_numel) { - float sf = (fp8_max / amax[idx]) / powf(2.0f, margin); - float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? 
sf : scale[idx]; - scale[idx] = scale_reg; - if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg; - amax_history[idx] = 0.0f; - } -} - -constexpr int BLOCK_SIZE = 512; - -void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, int64_t fp8_dtype, - float margin, const std::string &amax_compute) { - auto amax_history_ = MakeNvteTensor(amax_history); - auto scale_ = MakeNvteTensor(scale); - auto scale_inv_ = MakeNvteTensor(scale_inv); - const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask); - nvte_delayed_scaling_recipe_amax_and_scale_update( - amax_history_.data(), scale_.data(), scale_inv_.data(), non_weight_mask_.data(), - amax_history_.data(), scale_.data(), scale_inv_.data(), amax_compute.c_str(), - static_cast(fp8_dtype), margin, amax_history.stream()); -} - -void amax_and_scale_update_inplace_legacy( - paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, - const paddle::optional ¤t_step_id_tensor, bool update_weight_scale_inv, - bool fwd_update, float fp8_max, float margin, const std::string &amax_compute) { -#if PADDLE_VERSION > 261 - NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); - - paddle::Tensor amax; - - if (amax_compute == "max") { - amax = amax_history.max({0}); - } else { - amax = amax_history.slice(0, 1); - } - - const auto rolled_amax_history = amax_history.roll({-1}, {0}); - - auto amax_history_numel = amax_history.numel(); - auto amax_numel = amax.numel(); - size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE; - - const int *current_step_id_ptr = - reinterpret_cast(GetOptionalDataPtr(current_step_id_tensor)); - auto parameterSetter = [current_step_id_ptr, - fwd_update](phi::backends::gpu::gpuKernelParams ¶ms) { - if (fwd_update) { - 
int current_step_id = *current_step_id_ptr; - params.As(7) = (current_step_id == 0); - } - }; - - const float *amax_ptr = amax.data(); - const float *rolled_amax_history_ptr = rolled_amax_history.data(); - const bool *non_weight_mask_ptr = non_weight_mask.data(); - float *amax_history_ptr = amax_history.data(); - float *scale_ptr = scale.data(); - float *scale_inv_ptr = scale_inv.data(); - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&UpdateFP8MetaKernel); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - UpdateFP8MetaKernel<<>>( - id, amax_ptr, rolled_amax_history_ptr, non_weight_mask_ptr, amax_history_ptr, scale_ptr, - scale_inv_ptr, update_weight_scale_inv, margin, fp8_max, amax_history_numel, - amax_numel); - NVTE_CHECK_CUDA(cudaGetLastError()); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - NVTE_ERROR( - "amax_and_scale_update_inplace_legacy is not supported in old version of PaddlePaddle\n"); -#endif -} - -void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT - const paddle::Tensor &amax) { - // Copy amax to history[0] - NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), amax.numel() * SizeOf(amax.dtype()), - cudaMemcpyDeviceToDevice, amax.stream())); -} - -__global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel( - const bool *mask, int32_t *q_actual_seqlen, int32_t *kv_actual_seqlen, int q_seqlen, - int kv_seqlen, bool need_kv) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage q_smem; - __shared__ typename BlockReduce::TempStorage kv_smem; - unsigned int tid = threadIdx.x; - unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen; - - // load mask, convert to 1/0, do accumulation - int q = 0, kv = 0; - for 
(unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen; - q_idx += BLOCK_SIZE * kv_seqlen) { - q += (mask[q_idx + batch_offset] ? 0 : 1); - } - - if (need_kv) { - for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) { - kv += (mask[kv_idx + batch_offset] ? 0 : 1); - } - } - __syncthreads(); - - // compute cub::BlockReduce - int q_sum, kv_sum; - q_sum = BlockReduce(q_smem).Sum(q); - if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv); - - // write result for this block to global mem - if (tid == 0) { - q_actual_seqlen[blockIdx.x + 1] = q_sum; - if (need_kv) { - kv_actual_seqlen[blockIdx.x + 1] = kv_sum; - } - } -} - -__global__ __launch_bounds__(BLOCK_SIZE) void block_prefix_sum_inplace(int32_t *x, int n) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage smem; - // +1 to ignore the first element - int i = blockIdx.x * blockDim.x + threadIdx.x + 1; - - // load data - int32_t thread_data[1]; - thread_data[0] = i < n ? x[i] : 0; - __syncthreads(); - - // CUB block prefix sum - BlockScan(smem).InclusiveSum(thread_data, thread_data); - __syncthreads(); - - // write result - if (i < n) { - x[i] = thread_data[0]; - } -} - -void mask_to_cu_seqlens(const paddle::Tensor &mask, - paddle::Tensor &q_cu_seqlen, // NOLINT - paddle::optional &kv_cu_seqlen, // NOLINT - int q_seqlen, int kv_seqlen, bool need_kv) { - if (need_kv) { - NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr, - "kv_cu_seqlen must be provided when need_kv is true"); - } - mask_to_actual_seqlens_kernel<<>>( - mask.data(), q_cu_seqlen.data(), - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, need_kv); - // q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block - // to do prefix sum - NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail"); - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data(), - q_cu_seqlen.numel()); - 
if (need_kv) { - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>( - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel()); - } -} - -} // namespace paddle_ext -} // namespace transformer_engine - -PD_BUILD_OP(te_gemm) - .Inputs({"A", paddle::Optional("A_scale_inverse"), "B", paddle::Optional("B_scale_inverse"), - paddle::Optional("bias"), "_D", paddle::Optional("_D_scale"), - paddle::Optional("_D_amax"), paddle::Optional("_pre_gelu_out"), "_workspace"}) - .Outputs({"D", paddle::Optional("D_scale"), paddle::Optional("D_amax"), - paddle::Optional("pre_gelu_out"), "workspace"}) - .Attrs({"A_index: int64_t", "B_index: int64_t", "D_index: int64_t", "A_type: int64_t", - "B_type: int64_t", "D_type: int64_t", "bias_type: int64_t", "transa: bool", - "transb: bool", "grad: bool", "workspace_size: int64_t", "accumulate: bool", - "use_split_accumulator: bool", "math_sm_count: int64_t"}) - .SetInplaceMap({{"_D", "D"}, - {paddle::Optional("_D_scale"), paddle::Optional("D_scale")}, - {paddle::Optional("_D_amax"), paddle::Optional("D_amax")}, - {paddle::Optional("_pre_gelu_out"), paddle::Optional("pre_gelu_out")}, - {"_workspace", "workspace"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm)); - -PD_BUILD_OP(cast_to_fp8) - .Inputs({"Input", "Scale", "_Output", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetInplaceMap({{"_Output", "Output"}, {"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8)); - -PD_BUILD_OP(cast_from_fp8) - .Inputs({"Input", "ScaleInv"}) - .Outputs({"Output"}) - .Attrs({"index: int64_t", "itype: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_from_fp8)); - -PD_BUILD_OP(te_transpose) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose)); - 
-PD_BUILD_OP(te_cast_transpose) - .Inputs({"Input", "Scale", "_CastedOutput", "_TransposedOutput", "_Amax", "_ScaleInv"}) - .Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_CastedOutput", "CastedOutput"}, - {"_TransposedOutput", "TransposedOutput"}, - {"_Amax", "Amax"}, - {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose)); - -PD_BUILD_OP(te_cast_transpose_bgrad) - .Inputs({"GradOutput", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"dBias", "CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad)); - -PD_BUILD_OP(te_gelu_fp8) - .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu_fp8)); - -PD_BUILD_OP(te_gelu) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu)); - -PD_BUILD_OP(te_swiglu) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu)); - -PD_BUILD_OP(te_swiglu_fp8) - .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8)); - -PD_BUILD_OP(te_dswiglu) - .Inputs({"Grad", "Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu)); - 
-PD_BUILD_OP(te_cast_transpose_bgrad_dgelu) - .Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad_dgelu)); - -PD_BUILD_OP(te_layernorm_fwd_fp8) - .Inputs({"Input", "Weight", "Bias", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Mu", "Rsigma", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t", - "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd_fp8)); - -PD_BUILD_OP(te_layernorm_fwd) - .Inputs({"Input", "Weight", "Bias"}) - .Outputs({"Output", "Mu", "Rsigma"}) - .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd)); - -PD_BUILD_OP(te_layernorm_bwd) - .Inputs({"Dz", "X", "Mu", "Rsigma", "Gamma"}) - .Outputs({"Dx", "Dgamma", "Dbeta"}) - .Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_bwd)); - -PD_BUILD_OP(te_rmsnorm_fwd) - .Inputs({"Input", "Weight"}) - .Outputs({"Output", "InvVariance"}) - .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd)); - -PD_BUILD_OP(te_rmsnorm_fwd_fp8) - .Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "InvVariance", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t", - "zero_centered_gamma: bool"}) - 
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8)); - -PD_BUILD_OP(te_rmsnorm_bwd) - .Inputs({"Dz", "X", "Rsigma", "Gamma"}) - .Outputs({"Dx", "Dgamma"}) - .Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd)); - -PD_BUILD_OP(te_fused_attn_fwd_qkvpacked) - .Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"), - "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", - "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked)); - -PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) - .Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias"), - "rng_state"}) - .Outputs({"dQKV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", - "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "deterministic: bool"}) - .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked)); - -PD_BUILD_OP(te_fused_attn_fwd_kvpacked) - .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", - paddle::Optional("_softmax_aux"), "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: 
int64_t", "d: int64_t", "total_seqs_q: int64_t", - "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", - "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked)); - -PD_BUILD_OP(te_fused_attn_bwd_kvpacked) - .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV", - paddle::Optional("_dBias"), "rng_state"}) - .Outputs({"dQ", "dKV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", - "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", - "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "deterministic: bool"}) - .SetInplaceMap({{"_dQ", "dQ"}, - {"_dKV", "dKV"}, - {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked)); - -PD_BUILD_OP(te_fused_attn_fwd) - .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", - paddle::Optional("_softmax_aux"), "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", - "max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float", - "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t", "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", 
"rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd)); - -PD_BUILD_OP(te_fused_attn_bwd) - .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK", - "_dV", paddle::Optional("_dBias"), "rng_state"}) - .Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", - "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", - "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t", "deterministic: bool"}) - .SetInplaceMap({{"_dQ", "dQ"}, - {"_dK", "dK"}, - {"_dV", "dV"}, - {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd)); - -PD_BUILD_OP(te_scaled_softmax_forward) - .Inputs({"input"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_forward)); - -PD_BUILD_OP(te_scaled_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_backward)); - -PD_BUILD_OP(te_scaled_masked_softmax_forward) - .Inputs({"input", "mask"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_forward)); - -PD_BUILD_OP(te_scaled_masked_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_backward)); - -PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_forward) - .Inputs({"input"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - 
.SetKernelFn( - PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_forward)); - -PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn( - PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward)); - -PD_BUILD_OP(amax_and_scale_update_inplace_legacy) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", - paddle::Optional("current_step_id_tensor")}) - .Outputs({"amax_history", "scale", "scale_inv"}) - .SetInplaceMap({{"_amax_history", "amax_history"}, - {"_scale", "scale"}, - {"_scale_inv", "scale_inv"}}) - .Attrs({"update_weight_scale_inv: bool", "fwd_update: bool", "fp8_max: float", "margin: float", - "amax_compute: std::string"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace_legacy)); - -PD_BUILD_OP(amax_and_scale_update_inplace) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"}) - .Outputs({"amax_history", "scale", "scale_inv"}) - .SetInplaceMap({{"_amax_history", "amax_history"}, - {"_scale", "scale"}, - {"_scale_inv", "scale_inv"}}) - .Attrs({"fp8_dtype: int64_t", "margin: float", "amax_compute: std::string"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace)); - -PD_BUILD_OP(update_latest_amax_history_inplace) - .Inputs({"_history", "amax"}) - .Outputs({"history"}) - .SetInplaceMap({{"_history", "history"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace)); - -PD_BUILD_OP(mask_to_cu_seqlens) - .Inputs({"mask", "_q_cu_seqlen", paddle::Optional("_kv_cu_seqlen")}) - .Outputs({"q_cu_seqlen", paddle::Optional("kv_cu_seqlen")}) - .Attrs({"q_seqlen: int", "kv_seqlen: int", "need_kv: bool"}) - .SetInplaceMap({{"_q_cu_seqlen", "q_cu_seqlen"}, - {paddle::Optional("_kv_cu_seqlen"), 
paddle::Optional("kv_cu_seqlen")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::mask_to_cu_seqlens)); diff --git a/transformer_engine/paddle/csrc/extensions.cpp b/transformer_engine/paddle/csrc/extensions.cpp deleted file mode 100644 index 44ad2e7511..0000000000 --- a/transformer_engine/paddle/csrc/extensions.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "common.h" - -namespace transformer_engine { -namespace paddle_ext { - -size_t get_cublasLt_version() { return cublasLtGetVersion(); } - -PYBIND11_MODULE(transformer_engine_paddle, m) { - // Misc - m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); - m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); - // Data structures - py::enum_(m, "DType", py::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) - 
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); -} -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py deleted file mode 100644 index 0e91341b80..0000000000 --- a/transformer_engine/paddle/distributed.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Methods needed for distributed training.""" - -import os -import warnings -from contextlib import contextmanager -from typing import Any, Optional, Union, Tuple - -import paddle - -import paddle.distributed.fleet.base.topology as tp -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -from paddle.distributed.fleet.layers.mpu import mp_ops - -try: - # This feature is not supported as of Paddle 2.6. - from paddle.distributed.fleet.meta_parallel import ( - PipelineParallelMicroStepLocations, - register_global_pipeline_parallel_hook, - ) -except ImportError: - print("Cannot find register_global_pipeline_parallel_hook !") - register_global_pipeline_parallel_hook = None - -from .constants import dist_group_type - -_weight_split_axis = { - "transformer_engine": {"row": 1, "column": 0}, - "paddle": {"row": 0, "column": 1}, -} - - -def get_tp_group_and_world_size( - tp_group: Union[dist_group_type, None], enable_tp: bool = True -) -> Tuple[Union[dist_group_type, None], int]: - """Get TP group and world size using Fleet API""" - if not (paddle.distributed.is_initialized() and enable_tp): - return None, 1 - model_parallel_group = ( - tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group - ) - world_size = ( - tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() - if tp_group is None - else tp_group.nranks - ) - """ - When using TP, the NCCL communication needs to be scheduled - before the GEMM for a guaranteed overlap. From the host side - in TE, the comm calls are always launched first, but to ensure - that the GEMM isn't scheduled first, the environment variable - `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a - single channel. 
- """ - num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) - if num_cuda_work_queues != 1: - warnings.warn( - "To guarantee overlapping TP and SP collectives with the backward" - "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" - ) - - return model_parallel_group, world_size - - -def is_pp_enabled() -> bool: - """Check if pipeline parallel is enabled""" - if not paddle.distributed.is_initialized(): - return False - - return tp._HYBRID_PARALLEL_GROUP.get_pipe_parallel_world_size() > 1 - - -def register_pp_fwd_begin_hook(forward_begin_hook): - """Register the pp hook if register_global_pipeline_parallel_hook exist""" - if register_global_pipeline_parallel_hook is not None: - register_global_pipeline_parallel_hook( - PipelineParallelMicroStepLocations.FORWARD_BEGIN, forward_begin_hook - ) - - -@contextmanager -def track_rng_state(enable: bool, **kwargs) -> None: - """ - Applies get_rng_state_tracker().rng_state() to the context. - If not enabled, it does nothing. - """ - if enable: - with get_rng_state_tracker().rng_state(**kwargs): - yield - else: - yield - - -def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None: - """Set distributed attributes for the input tensor""" - tensor.is_distributed = is_parallel - if is_parallel: - tensor.split_axis = axis - - -def set_weight_tensor_dist_attr( - tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str -) -> None: - """Set distributed attributes for the weight tensor""" - if not is_parallel or parallel_mode is None: - return - set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode]) - - -def allreduce( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, -) -> Tuple[paddle.Tensor, Any]: - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. 
- if tp_group is None or tp_group.nranks == 1: - return input_ - - # All-reduce. - if sync_op: - output = mp_ops._mp_allreduce( - input_, - group=tp_group, - use_calc_stream=True, - use_model_parallel=True, - ) - return output, None - - wait_handle = paddle.distributed.all_reduce( - input_, - op=paddle.distributed.ReduceOp.SUM, - group=tp_group, - sync_op=False, - ) - - output = input_ - - return output, wait_handle - - -def allgather( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, - axis: int = 0, -) -> Tuple[paddle.Tensor, Any]: - """All-gather the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_, None - - parallelism = tp_group.nranks - output_shape = input_.shape - output_shape[axis] = output_shape[axis] * parallelism - output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op) - if sync_op: - wait_handle.wait() - return output, None - return output, wait_handle - - -def reduce_scatter( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, -) -> [paddle.Tensor, Any]: - """Reduce-scatter the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. 
- if tp_group is None or tp_group.nranks == 1: - return input_, None - - parallelism = tp_group.nranks - output_shape = input_.shape - assert input_.shape[0] % parallelism == 0, ( - f"Input sequence length {input_.shape[0]} can't be divided " - f"exactly by sequence parallelism {parallelism}" - ) - output_shape[0] = output_shape[0] // parallelism - output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = paddle.distributed.stream.reduce_scatter( - output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op - ) - if sync_op: - return output, None - return output, wait_handle - - -def identity( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, -) -> paddle.Tensor: - """ - Identity when forward. - Allreduce across model parallel group when backward. - """ - output = mp_ops._c_identity(input_, group=tp_group) - - return output - - -def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor): - """ - Set sequence_parallel attribute to input tensor. It is used for registering allreduce - hooks in PaddleNLP sequence parallel training. - """ - setattr(parameter, "sequence_parallel", True) diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py deleted file mode 100644 index 7313a81975..0000000000 --- a/transformer_engine/paddle/fp8.py +++ /dev/null @@ -1,370 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""FP8 utilities for TransformerEngine""" - -from contextlib import contextmanager -from typing import Tuple, Optional, Dict, Any, Union - -import numpy as np - -import paddle -from transformer_engine import transformer_engine_paddle as tex -from transformer_engine.common.recipe import DelayedScaling, Format - -from .constants import dist_group_type -from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer - -__all__ = ["fp8_autocast"] - -# FP8 support -_is_fp8_available = None -_reason_for_no_fp8 = "" - - -def _check_fp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - - # Check GPU arch - arch = paddle.device.cuda.get_device_capability() - if arch >= (9, 0): # hopper and above - return True, "" - if arch < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - - # Special handling for Ada - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if not paddle.version.cuda(): - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - if tuple(int(v) for v in paddle.version.cuda().split(".")) < (12, 1): - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." 
- return True, "" - - -def is_fp8_available() -> Tuple[bool, str]: - """Return if fp8 support is available""" - global _is_fp8_available, _reason_for_no_fp8 - if _is_fp8_available is None: - _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support() - return _is_fp8_available, _reason_for_no_fp8 - - -class FP8State: - """Stores FP8 state""" - - def __init__(self): - self._fp8_enabled = False - self._fp8_calibration = False - self._fp8_recipe = None - self._fp8_distributed_group = None - self._is_first_fp8_module = False - self._fp8_autocast_counter = 0 - self._fp8_autocast_depth = 0 - self._fp8_recompute_enabled = False - self._use_cudagraph = False - self._fp8_fwd_buffer = FP8MetaFwdBuffer() - self._fp8_bwd_buffer = FP8MetaBwdBuffer() - self._fp8_recompute_buffer = FP8RecomputeBuffer() - - def is_fp8_enabled(self) -> bool: - """Is FP8 enabled""" - return self._fp8_enabled - - def is_fp8_calibration(self) -> bool: - """Is FP8 calibration""" - return self._fp8_calibration - - def get_fp8_recipe(self) -> DelayedScaling: - """Return the fp8 recipe""" - return self._fp8_recipe - - @staticmethod - def get_default_fp8_recipe() -> DelayedScaling: - """FP8 recipe with default args.""" - return DelayedScaling() - - def get_autocast_id(self) -> int: - """Returns the number of times of entering the `fp8_autocast` context. - as a unique ID for different training steps.""" - return self._fp8_autocast_counter - - def is_first_fp8_module(self): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. 
- """ - tmp = self._is_first_fp8_module - self._is_first_fp8_module = False - return tmp - - def get_fp8_group(self) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return self._fp8_distributed_group - - def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer: - """Returns global fp8 forward buffer.""" - return self._fp8_fwd_buffer - - def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer: - """Returns global fp8 backward buffer.""" - return self._fp8_bwd_buffer - - def is_fp8_recompute_enabled(self) -> bool: - """Is FP8 recompute enabled""" - return self._fp8_recompute_enabled - - def get_fp8_recompute_buffer(self) -> FP8RecomputeBuffer: - """Returns global fp8 recompute buffer.""" - return self._fp8_recompute_buffer - - def is_cudagraph_enabled(self) -> bool: - """Is CUDAGraph enabled""" - return self._use_cudagraph - - def enable_cudagraph(self): - """Enable CUDA Graphs. Once CUDA Graphs are enabled, they cannot be disabled within the same execution context at current implementation.""" - self._use_cudagraph = True - self._fp8_fwd_buffer.enable_cudagraph() - self._fp8_bwd_buffer.enable_cudagraph() - if self._fp8_recompute_enabled: - raise RuntimeError("Currently, We do not allow recompute with cudagraph") - - def enter( - self, - enabled: bool, - calibrating: bool, - fp8_recipe: Optional[DelayedScaling], - fp8_group: Optional[dist_group_type], - ) -> None: - """Called when entering 'fp8_autocast'""" - self.saved_states = ( - self._fp8_enabled, - self._fp8_calibration, - self._fp8_recipe, - self._fp8_distributed_group, - self._is_first_fp8_module, - ) - - self._fp8_enabled = enabled - self._fp8_calibration = calibrating - self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - self._fp8_distributed_group = fp8_group - - if self._fp8_autocast_depth == 0: - self._is_first_fp8_module = True - self._fp8_autocast_counter += 1 - self._fp8_autocast_depth += 1 - - def exit(self): - """Called when exiting 
'fp8_autocast'""" - # Restore saved states - ( - self._fp8_enabled, - self._fp8_calibration, - self._fp8_recipe, - self._fp8_distributed_group, - self._is_first_fp8_module, - ) = self.saved_states - - self._fp8_autocast_depth -= 1 - - if self._fp8_autocast_depth == 0: - self._fp8_fwd_buffer.finalize() - - -_global_fp8_state = FP8State() - - -def get_global_fp8_state() -> FP8State: - """Get global fp8 state""" - return _global_fp8_state - - -@contextmanager -def fp8_autocast( - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, - fp8_group: Optional[dist_group_type] = None, -) -> None: - """ - Context manager for FP8 usage. - - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. - - Parameters - ---------- - enabled: bool, default = `False` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` - recipe used for FP8 training. 
- fp8_group: paddle.distributed.collective.Group, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - try: - _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) - - if enabled: - fp8_available, reason_for_no_fp8 = is_fp8_available() - assert fp8_available, reason_for_no_fp8 - yield - finally: - _global_fp8_state.exit() - - -def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def amax_and_scale_update( - fp8_meta: Dict[str, Any], - fwd_update: bool, - update_weight_scale_inv: bool = True, - current_step_id_tensor: Optional[paddle.Tensor] = None, - use_cudagraph: bool = False, -) -> None: - """Updates fp8 amaxes/scales for fwd | bwd.""" - amax_compute = fp8_meta["recipe"].amax_compute_algo - sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo - fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" - fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" - - if not callable(amax_compute) and sf_compute is None: - non_weight_mask = fp8_meta[fp8_meta_tensor_key].non_weight_mask - - if use_cudagraph: - tex.amax_and_scale_update_inplace_legacy( - _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, - _scale=fp8_meta[fp8_meta_tensor_key].scale, - _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, - non_weight_mask=non_weight_mask, - current_step_id_tensor=current_step_id_tensor, - update_weight_scale_inv=update_weight_scale_inv, - fwd_update=fwd_update, - fp8_max=fp8_meta[fp8_max_key], - margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute, - ) - else: - if update_weight_scale_inv: - # we pass nullptr into kernel when we need to update_weight_scale_inv - 
non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace( - _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, - _scale=fp8_meta[fp8_meta_tensor_key].scale, - _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, - non_weight_mask=non_weight_mask, - fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)), - margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute, - ) - - else: - raise ValueError( - "We only support the fp8 recipe with 'max' or 'most_recent' " - "amax_compute_algo and default scaling_factor_compute_algo at this " - "moment." - ) - - -class FP8TensorMeta: - """Holds FP8 scaling and amax history for FP8 layers""" - - def __init__(self, is_forward: bool): - self.scale = paddle.Tensor() - self.scale_inv = paddle.Tensor() - self.amax_history = paddle.Tensor() - self.non_weight_mask = paddle.Tensor() - self.is_initialized = False - self.is_forward = is_forward - - def get_non_weight_mask(self, num_gemms: int): - """Needed for calculation of scale inverses to - preserve scale_inv when caching FP8 weights""" - if self.is_forward: - # [True, False, True]: -> [input, weight, output] - return paddle.to_tensor([True, False, True] * num_gemms) - # [True, True]: -> [grad_output, grad_input] - return paddle.to_tensor([True, True] * num_gemms) - - def prepare(self, num_gemms: int, amax_history_len: int) -> None: - """Prepare scales and amax tensors. It is called during fprop in each iteration. - If the meta tensors are not initialized yet, initialization is performed. If already - initialized, resize the meta tensors if amax_history_len has changed.""" - - if self.is_initialized: - # Handle changed amax history size. 
- curr_len = self.amax_history.shape[0] - num_fp8_tensors = self.amax_history.shape[1] - if amax_history_len < curr_len: - self.amax_history = self.amax_history[:amax_history_len] - elif amax_history_len > curr_len: - extra_rows = amax_history_len - curr_len - self.amax_history = paddle.concat( - [ - self.amax_history, - paddle.zeros((extra_rows, num_fp8_tensors), dtype="float32"), - ], - axis=0, - ) - return - - # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and - # 2 (grad_output and grad_input) for bwd - num_fp8_tensors = num_gemms * 3 if self.is_forward else num_gemms * 2 - - self.scale = paddle.ones(num_fp8_tensors, dtype="float32") - self.scale_inv = paddle.ones(num_fp8_tensors, dtype="float32") - self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype="float32") - self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) - - self.is_initialized = True - - def to_numpy(self): - """Convert FP8 meta tensors to numpy.""" - assert self.is_initialized, "FP8TensorMeta is not initialized yet." - return { - "scale": self.scale.numpy(), - "scale_inv": self.scale_inv.numpy(), - "amax_history": self.amax_history.numpy(), - } - - def from_numpy(self, data: Dict[str, np.array]): - """Set FP8 meta tensors from numpy""" - self.scale = paddle.to_tensor(data["scale"]) - self.scale_inv = paddle.to_tensor(data["scale_inv"]) - self.amax_history = paddle.to_tensor(data["amax_history"]) - - num_fp8_tensors = self.scale.shape[0] - num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2 - self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) - - self.is_initialized = True diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py deleted file mode 100644 index 06a9355e72..0000000000 --- a/transformer_engine/paddle/fp8_buffer.py +++ /dev/null @@ -1,350 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# -# See LICENSE for license information. -"""FP8 meta buffer for FP8 amax reduction""" - -from abc import ABC, abstractmethod -from collections import deque -from functools import partial -import os -from typing import Dict, Any, List, Union - -import numpy as np -import paddle -from transformer_engine import transformer_engine_paddle as tex - -from .constants import dist_group_type, RecomputeFunctionNames - - -class FP8MetaBufferBase(ABC): - """ - A global buffer that holds FP8 meta for reduction across trainers. - """ - - def __init__(self): - self._global_amax = {} - self._buffer_delete_key = None - self._amax_reduce_wait_func = None - self._dp_amax_reduce_interval = None - self._contiguous_amax = None - self._use_cudagraph = False - self._dp_amax_reduce_idx = 0 - - @staticmethod - @abstractmethod - def _get_meta_tensor_key(): - """Returns scaling key in `fp8_meta`.""" - - @staticmethod - @abstractmethod - def _get_buffer_position_key(): - """Returns module position key in `fp8_meta`.""" - - @staticmethod - @abstractmethod - def _get_autocast_key(): - """Returns autocast id key in `fp8_meta`.""" - - def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str: - """Return a key in `_global_amax` for the AMAX storage.""" - return f"AMAX_{fp8_meta[self._get_autocast_key()]}" - - def _execute_deletion(self) -> None: - """Delete the key from global amax buffer.""" - if self._buffer_delete_key is not None and self._buffer_delete_key in self._global_amax: - del self._global_amax[self._buffer_delete_key] - - def _wait_handle_and_split( - self, - contiguous_amax: paddle.Tensor, - chunk_sizes: List[int], - amax_buffer_key: str, - wait_handle: Union[bool, None], - ) -> None: - """Wait for amax reduction to finish and then copy reduced amax to buffer""" - if wait_handle is not None: - wait_handle.wait() - if self._use_cudagraph: - splited_list = list(contiguous_amax.split(chunk_sizes)) - for amax, split in zip(self._global_amax[amax_buffer_key], splited_list): - 
amax.copy_(split, False) - else: - self._global_amax[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) - - def _global_amax_reduction( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """Concatenate, reduce, and split amaxes in the global buffer.""" - - def _reduce_tensor_across_group_op_max(tensor, group, sync_op): - if paddle.distributed.is_initialized(): - wait_handle = paddle.distributed.all_reduce( - tensor, - op=paddle.distributed.ReduceOp.MAX, - group=group, - sync_op=sync_op, - ) - return wait_handle - return None - - amax_buffer_key = self._get_amax_buffer_key(fp8_meta) - # Key already deleted. - if amax_buffer_key not in self._global_amax: - return None - - # Reduce AMAX in DP-domain at an interval. - if self._dp_amax_reduce_interval is None: - self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) - - tp_amax_reduce = False - reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group. 
- if self._dp_amax_reduce_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval - - if tp_amax_reduce: - if tp_size > 1: - reduce_group = tp_group - else: - return None - - chunk_sizes = [x.shape[0] for x in self._global_amax[amax_buffer_key]] - if self._use_cudagraph: - # we need to ensure the _contiguous_amax is address-stable under cudagraph - if self._contiguous_amax is None: - self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key]) - else: - self._contiguous_amax.copy_( - paddle.concat(self._global_amax[amax_buffer_key]), False - ) - else: - self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key]) - - wait_handle = _reduce_tensor_across_group_op_max( - self._contiguous_amax, - reduce_group, - not fp8_meta["async_amax_reduction"], - ) - - if wait_handle is not None and self._use_cudagraph: - # we need to ensure record/wait does not cross the boundary of the graph - wait_handle.wait() - wait_handle = None - - return partial( - self._wait_handle_and_split, - self._contiguous_amax, - chunk_sizes, - amax_buffer_key, - wait_handle, - ) - - def add_amax(self, fp8_meta: Dict[str, Any]) -> None: - """Append `amax_history` to global buffer.""" - buffer_key = self._get_amax_buffer_key(fp8_meta) - fp8_meta_tensor_key = self._get_meta_tensor_key() - buffer_position_key = self._get_buffer_position_key() - - if buffer_key not in self._global_amax: - self._global_amax[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - else: - self._global_amax[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - - if buffer_position_key not in fp8_meta: - fp8_meta[buffer_position_key] = len(self._global_amax[buffer_key]) - 1 - - # Catch incorrect fp8_autocast usage. 
- assert fp8_meta[buffer_position_key] == len(self._global_amax[buffer_key]) - 1, ( - "Same module is being invoked more than once inside an `fp8_autocast` " - "region when using FP8 with amax reduction. This behavior is currently " - "unsupported. For more details and correct usage, please see " - "https://github.com/NVIDIA/TransformerEngine/pull/93." - ) - - def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None: - """Populate current amax with the correct location from buffer.""" - fp8_meta_tensor_key = self._get_meta_tensor_key() - buffer_position_key = self._get_buffer_position_key() - if buffer_position_key not in fp8_meta: - return - - amax_buffer_key = self._get_amax_buffer_key(fp8_meta) - assert amax_buffer_key in self._global_amax, "TE internal error." - - # Copy amax to amax_history[0] - tex.update_latest_amax_history_inplace( - _history=fp8_meta[fp8_meta_tensor_key].amax_history, - amax=self._global_amax[amax_buffer_key][fp8_meta[buffer_position_key]], - ) - - def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None: - """Delete this amax key from global buffer during autocast end.""" - if self._get_autocast_key() not in fp8_meta: - return - self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta) - - def get_amax_reduce_handle(self) -> Union[bool, None]: - """Return AMAX reduction wait handle.""" - return self._amax_reduce_handle - - def wait(self) -> None: - """Wait for reduced amax to be available in buffer.""" - if self._amax_reduce_wait_func is not None: - self._amax_reduce_wait_func() # pylint: disable=not-callable - self._amax_reduce_wait_func = None - - def to_numpy(self) -> Dict[str, List[np.array]]: - """Convert to numpy arrays""" - out = {} - for k, v in self._global_amax.items(): - out[k] = [tensor.numpy() for tensor in v] - return out - - def from_numpy(self, buffer: Dict[str, np.array]) -> None: - """Set buffer values from numpy arrays""" - for k, v in buffer.items(): - self._global_amax[k] = [paddle.to_tensor(arr) for 
arr in v] - - def enable_cudagraph(self): - """Enable CUDA Graphs.""" - self._use_cudagraph = True - - -class FP8MetaFwdBuffer(FP8MetaBufferBase): - """FP8Meta Buffer for forward""" - - @staticmethod - def _get_meta_tensor_key() -> str: - """Returns scaling key in `fp8_meta`.""" - return "scaling_fwd" - - @staticmethod - def _get_buffer_position_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "global_fp8_buffer_pos_fwd" - - @staticmethod - def _get_autocast_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "autocast_id_fwd" - - def set_for_amax_reduction( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """Sets up the function to call during autocast exit.""" - self._amax_global_reduce_func = partial( - self._global_amax_reduction, - fp8_meta, - tp_group, - tp_size, - ) - - def finalize(self) -> None: - """ - Called at FP8 autocast end. - Performs AMAX reduction and delete unused buffer entries. - """ - if hasattr(self, "_amax_global_reduce_func") and callable(self._amax_global_reduce_func): - self._amax_reduce_wait_func = self._amax_global_reduce_func() - self._execute_deletion() - - -class FP8MetaBwdBuffer(FP8MetaBufferBase): - """FP8Meta Buffer for backward""" - - @staticmethod - def _get_meta_tensor_key() -> str: - """Returns scaling key in `fp8_meta`.""" - return "scaling_bwd" - - @staticmethod - def _get_buffer_position_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "global_fp8_buffer_pos_bwd" - - @staticmethod - def _get_autocast_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "autocast_id_bwd" - - def finalize( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """ - Called at FP8 autocast end in backward. - Performs AMAX reduction and delete unused buffer entries. 
- """ - self._amax_reduce_wait_func = self._global_amax_reduction( - fp8_meta, tp_group, tp_size - ) # _wait_handle_and_split - self._execute_deletion() - - -class FP8RecomputeBuffer: - """Buffer used to hold FP8 meta tensors for recompute""" - - def __init__(self): - self._global_amax = [] - - @staticmethod - def get_buffer_position_key(): - """Returns the key (in fp8_meta) for recompute buffer position""" - return "recompute_buffer_pos" - - def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: - """Stash the scaling factors and amaxes for recompute""" - buffer_position_key = self.get_buffer_position_key() - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ] - - if buffer_position_key in fp8_meta: - self._global_amax[fp8_meta[buffer_position_key]].append(to_copy) - else: - self._global_amax.append(deque()) - self._global_amax[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(self._global_amax) - 1 - - def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: - """Switch to the previously saved scaling factors and amaxes""" - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale - fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv - - # Retrieve stashed amaxes and scales from phase 1 pre forward. 
- buffer_position_key = self.get_buffer_position_key() - stashed_fp8_meta = self._global_amax[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - assert "updated_amax_history_fwd" in fp8_meta, ( - "Recompute internal error." - " If you are not using recompute, please check if" - " the forward function is called from one of these functions: " - f"{RecomputeFunctionNames}. If so, consider change the function name " - "or set NVTE_DISABLE_RECOMPUTE=1." - ) - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] diff --git a/transformer_engine/paddle/layer/__init__.py b/transformer_engine/paddle/layer/__init__.py deleted file mode 100644 index 4d81ca231a..0000000000 --- a/transformer_engine/paddle/layer/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Layer level Paddle APIs""" - -from .attention import DotProductAttention, MultiHeadAttention, RotaryPositionEmbedding -from .layernorm import LayerNorm -from .layernorm_linear import LayerNormLinear -from .layernorm_mlp import LayerNormMLP -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax -from .transformer import TransformerLayer diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py deleted file mode 100644 index d3b0950dee..0000000000 --- a/transformer_engine/paddle/layer/attention.py +++ /dev/null @@ -1,1161 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Attntion API""" - -import math -import os -import warnings -from typing import Optional, Tuple, Union - -import paddle -import paddle.nn.functional as F - -try: - from paddle.incubate.nn.functional import fused_rotary_position_embedding -except ImportError: - fused_rotary_position_embedding = None -from transformer_engine import transformer_engine_paddle as tex - -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax -from ..constants import ( - AttnTypes, - TE_DType, - AttnBiasType, - AttnMaskType, - FusedAttnBackend, - dist_group_type, -) -from ..cpp_extensions import ( - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, - fused_attn_fwd, - fused_attn_bwd, - mask_to_cu_seqlens, -) -from ..distributed import get_tp_group_and_world_size, track_rng_state -from ..utils import attention_mask_func, divide -from ..recompute import recompute - -__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"] - - -def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: - """ - Used to repeat the key and value states for GQA. 
- The hidden states go from (batch, seqlen, num_gqa_groups, head_size) - to (batch, seqlen, num_heads, head_size) - """ - batch, seqlen, num_gqa_groups, head_size = hidden_states.shape - if n_rep == 1: - return hidden_states - - hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) - return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size]) - - -class RotaryPositionEmbedding(paddle.nn.Layer): - """ - Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. - """ - - def __init__( - self, - dim: int, - max_position_embeddings: int, - ): - """ - Parameters - ---------- - dim: int - rotary embedding dimension - max_position_embeddings: int - max_position_embeddings before position interpolation - """ - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.inv_freq = 1.0 / ( - 10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim) - ) - self._set_cos_sin_cache(seq_len=max_position_embeddings) - - def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - # [seq_len] - t = paddle.arange(seq_len, dtype="float32") - # [seq_len, dim/2] - freqs = paddle.einsum("i,j->ij", t, self.inv_freq) - # [seq_len, dim] - emb = paddle.concat([freqs, freqs], axis=-1) - # [1, seqlen, 1, dim] - self.cos_cached = emb.cos()[None, :, None, :] - self.sin_cached = emb.sin()[None, :, None, :] - - def forward(self, max_seq_len: int): - """ - Create rotary position embedding frequencies - - Parameters - ---------- - max_seq_len: int - sequence length of a sample - """ - cos = self.cos_cached[:, :, :max_seq_len, ...] - sin = self.sin_cached[:, :, :max_seq_len, ...] 
- return (cos, sin) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return paddle.concat([-x2, x1], axis=-1) # shape is the same as x - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): - """Applies rotary positional embedding to the input.""" - - if position_ids is None: - # Note: Only for LlamaForCausalLMPipe model pretraining - cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] - sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] - else: - cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] - sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): - """Function for FusedAttention with packed QKV input""" - - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - attn_bias, - max_seqlen, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with packed QKV input""" - out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked( - qkv, - cu_seqlens, - is_training, - max_seqlen, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux) - ctx.max_seqlen = max_seqlen - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = 
fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with packed QKV input""" - qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor() - dqkv, *rest = fused_attn_bwd_qkvpacked( - qkv, - cu_seqlens, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - - # if no_bias, return dqkv - if ctx.attn_bias_type == "no_bias": - return (dqkv, None) - # else, return (dqkv, dbias) - return (dqkv, None, rest[0]) - - -class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): - """Function for FusedAttention with packed KV input""" - - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - attn_bias, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with packed KV input""" - out, softmax_aux, rng_state = fused_attn_fwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - is_training, - max_seqlen_q, - max_seqlen_kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - 
"""Backward function for FusedAttention with packed KV input""" - q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dkv, *rest = fused_attn_bwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - - # if no_bias, return dq, dkv - if ctx.attn_bias_type == "no_bias": - return (dq, dkv, None, None) - # else, return (dq, dkv, dbias) - return (dq, dkv, None, None, rest[0]) - - -class FusedAttnFunc(paddle.autograd.PyLayer): - """Function for FusedAttention with separate Q, K, V tensors""" - - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - attn_bias, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with separate Q, K, V tensors""" - out, softmax_aux, rng_state = fused_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - is_training, - max_seqlen_q, - max_seqlen_kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, 
d_out): - """Backward function for FusedAttention with separate Q, K, V tensors""" - q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dk, dv, *rest = fused_attn_bwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - # if no_bias, return dq, dk, dv - if ctx.attn_bias_type == "no_bias": - return (dq, dk, dv, None, None) - # else, return (dq, dk, dv, dbias) - return (dq, dk, dv, None, None, rest[0]) - - -class DotProductAttention(paddle.nn.Layer): - """ - Allows the model to jointly attend to information from different - representation subspaces as described in the paper: - `Attention Is All You Need `_. - - .. note:: - - Argument :attr:`attention_mask` will be ignored in the `forward` call when - :attr:`attn_mask_type` is set to `"causal"`. - - .. warning:: - - Fused attention backward uses a non-deterministic algorithm when workspace - optimization is not enabled. To use a deterministic algorithm, set the - environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` - - Parameters - ---------- - num_attention_heads: int - number of attention heads in the transformer layer. - kv_channels: int - number of channels in the key and value tensors. - num_gqa_groups : Optional[int] = None - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the queries. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. 
- attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. - attention_type: {'self', 'cross'}, default = `self` - type of attention operation. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. - """ - - def __init__( - self, - num_attention_heads: int, - kv_channels: int, - num_gqa_groups: Optional[int] = None, - attention_dropout: float = 0.1, - attn_mask_type: str = "causal", - attention_type: str = "self", - tp_size: int = 1, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.attn_mask_type = attn_mask_type - self.attention_dropout = attention_dropout - self.attention_type = attention_type - self.qkv_layout = "bshd_bshd_bshd" - self.hidden_size_per_attention_head = kv_channels - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - self.tp_size = tp_size - self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) - self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups - - self.backend = backend - - self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) - - self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) - - # To use the workspace optimization path for determinism, please - # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0, - # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0. 
- cudnn_version = paddle.get_cudnn_version() - if 8905 <= cudnn_version < 9000: - if self.deterministic: - # workspace optimization path is deterministic - os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" - - # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT - # - unset: enables workspace optimization when required workspace is <= 256MB - # or when bias gradient needs to be computed - # - n: enables workspace optimization when required workspace is <= n bytes - # - -1: enables workspace optimization always - # - 0: disables workspace optimization always - if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ: - if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0": - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" - if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - - if not self.use_fused_attention and backend == "transformer_engine": - warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - self.backend = "paddle" - - if self.backend != "transformer_engine": - self.scale_mask_softmax = FusedScaleMaskSoftmax( - attn_mask_type, attention_mask_func, backend=self.backend - ) - - def forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - ) -> paddle.Tensor: - """ - Dot Product Attention Layer. - - .. note:: - - Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` - is set to `"causal"`. - - - Parameters - ---------- - query_layer : paddle.Tensor - Query tensor. - key_layer : paddle.Tensor - Key tensor. - value_layer : paddle.Tensor - Value tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. 
- core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to use the fast path to set output tensors to 0 or not. - """ - - backend = self.backend - - assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" - assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" - - if backend == "transformer_engine": - max_s_q = query_layer.shape[1] - max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1] - self.fused_attention_backend = tex.get_fused_attn_backend( - TE_DType[query_layer.dtype], - TE_DType[query_layer.dtype], - tex.get_nvte_qkv_layout(self.qkv_layout), - AttnBiasType[core_attention_bias_type], - AttnMaskType[self.attn_mask_type], - self.attention_dropout, - query_layer.shape[-2], - key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], - max_s_q, - max_s_kv, - query_layer.shape[-1], - ) - - is_backend_avail = self.fused_attention_backend in [ - FusedAttnBackend["F16_max512_seqlen"], - FusedAttnBackend["F16_arbitrary_seqlen"], - ] - if is_backend_avail and self.use_fused_attention: - return self._te_forward( - query_layer, - key_layer, - value_layer, - attention_mask, - core_attention_bias_type, - core_attention_bias, - set_zero, - ) - warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - backend = "paddle" - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.attn_mask_type, attention_mask_func, backend=backend - ) - if backend == "paddle": - if core_attention_bias_type != "no_bias": - warnings.warn( - "Paddle backend dot product attention does not support bias yet. " - "Bias will be ignored." 
- ) - return self._pd_forward(query_layer, key_layer, value_layer, attention_mask) - raise AttributeError(f"Backend {backend} is not supported.") - - def _te_forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - ) -> paddle.Tensor: - - if self.attention_type == "self": - # self attention - q: [b, s, h, d] kv: None - assert ( - len(query_layer.shape) == 4 - and len(key_layer.shape) == 4 - and len(value_layer.shape) == 4 - ), "q,k,v shape must be [b, s, h, d] for dot product self attention" - max_seqlen = query_layer.shape[1] - if self.attn_mask_type == "causal" or attention_mask is None: - cu_seqlens = paddle.arange( - 0, - (query_layer.shape[0] + 1) * query_layer.shape[1], - step=query_layer.shape[1], - dtype="int32", - ) - else: - cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False) - qkv_dtype = TE_DType[query_layer.dtype] - - output = FusedAttnFunc.apply( - query_layer, - key_layer, - value_layer, - cu_seqlens, - cu_seqlens, - core_attention_bias, - max_seqlen, - max_seqlen, - 1.0 / self.norm_factor, - qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, - self.qkv_layout, - core_attention_bias_type, - self.attn_mask_type, - self.training, - self.deterministic, - self.fused_attention_backend, - ) - elif self.attention_type == "cross": - # cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d] - assert ( - len(query_layer.shape) == 4 - and len(key_layer.shape) == 4 - and len(value_layer.shape) == 4 - ), ( - "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" - "for dot product cross attention" - ) - assert attention_mask is not None, "attention_mask must be provided for cross attention" - max_seqlen_q = query_layer.shape[1] - max_seqlen_kv = key_layer.shape[1] - cu_seqlens_q, 
cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True) - qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFunc.apply( - query_layer, - key_layer, - value_layer, - cu_seqlens_q, - cu_seqlens_kv, - core_attention_bias, - max_seqlen_q, - max_seqlen_kv, - 1.0 / self.norm_factor, - qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, - self.qkv_layout, - core_attention_bias_type, - self.attn_mask_type, - self.training, - self.deterministic, - self.fused_attention_backend, - ) - else: - raise ValueError("attention_type must be one of ['self', 'cross']") - return output - - def _pd_forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - ) -> paddle.Tensor: - - q = query_layer - k = repeat_kv(key_layer, self.num_queries_per_key_value) - v = repeat_kv(value_layer, self.num_queries_per_key_value) - - q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) - k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) - v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) - - product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True) - attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None) - - if self.attention_dropout > 0: - attention_probs = F.dropout( - attention_probs, - self.attention_dropout, - training=self.training, - ) - - out = paddle.matmul(attention_probs, v) - out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d] - # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) - return out - - -class MultiHeadAttention(paddle.nn.Layer): - """ - Multi-head Attention (MHA), including Query, - Key, Value and Output projection. - - Parameters - ---------- - hidden_size: int - hidden size of the model. - num_attention_heads: int - number of attention heads. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. 
- layernorm_epsilon: float, default = 1e-5 - epsilon to use in the layer norm operations. - weight_attr: Union[paddle.ParamAttr, None], default = `None` - paddle.ParamAttr object for the weight parameter. - bias_attr: Union[paddle.ParamAttr, None, bool], default = `None` - paddle.ParamAttr object for the bias parameter. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. - params_dtype: Optional[paddle.dtype], default = `None` - data type for the weights and biases. - return_layernorm_output: bool, default = `False` - whether to return the output of the layernorm operation. - input_layernorm: bool, default = `False` - whether to apply layernorm to the input. - attention_type: {'self', 'cross'}, default = `self` - type of attention operation. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - zero_centered_gamma: bool, default = `False` - whether to zero initialize the gamma of the layernorm operation. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. If set to 'paddle', a framework - only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel - whereas PROJ and FC2 is used as Row Parallel as described - `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - num_gqa_groups : int, default = `None` - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the querys. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. 
`num_gqa_groups = num_attention_heads`. - rng_state_name : str, default = `local_seed` - Controls the rng state used for dropout on attention probs. The - specified rng should be set different seeds for different TP ranks. - It will be ignored if `set_parallel_mode` is False. The specified - name should be registered through - `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() - .add(rng_state_name, seed)`. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - - """ - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - attention_dropout: float = 0.1, - layernorm_epsilon: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - return_layernorm_output: bool = False, - input_layernorm: bool = False, - attention_type: str = "self", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - num_gqa_groups: Optional[int] = None, - fuse_wgrad_accumulation: bool = False, - rng_state_name: str = "local_seed", - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.input_layernorm = input_layernorm - self.attention_type = attention_type - self.return_layernorm_output = return_layernorm_output - self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype - self.max_sequence_length = max_sequence_length - self.weight_attr 
= weight_attr - self.bias_attr = bias_attr - self.attn_mask_type = attn_mask_type - - assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" - - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.sequence_parallel = self.tensor_parallel and sequence_parallel - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - self.num_attention_heads = num_attention_heads - self.set_parallel_mode = set_parallel_mode - self.rng_state_name = rng_state_name - self.backend = backend - - self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) - self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - assert ( - self.num_attention_heads % self.num_gqa_groups == 0 - ), "The number of attention heads must be divisible by the number of GQA groups!" - assert ( - self.num_gqa_groups % self.tp_size == 0 - ), "The number of GQA groups must be divisible by tensor parallel size!" 
- self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) - self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads) - qkv_parallel_mode = "column" if set_parallel_mode else None - - if self.attention_type == "self": - if self.input_layernorm: - self.layernorm_qkv = LayerNormLinear( - hidden_size, - hidden_size + 2 * self.hidden_size_kv, - eps=layernorm_epsilon, - weight_attr=self.weight_attr, - bias_attr=self.bias_attr, - return_layernorm_output=return_layernorm_output, - normalization=normalization, - zero_centered_gamma=zero_centered_gamma, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - else: - self.qkv = Linear( - hidden_size, - hidden_size + 2 * self.hidden_size_kv, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - else: # cross attention - if self.input_layernorm: - self.layernorm_query = LayerNormLinear( - hidden_size, - hidden_size, - eps=layernorm_epsilon, - weight_attr=self.weight_attr, - bias_attr=self.bias_attr, - return_layernorm_output=return_layernorm_output, - normalization=normalization, - zero_centered_gamma=zero_centered_gamma, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - else: - self.query_layer = Linear( - hidden_size, - hidden_size, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - self.key_value = Linear( - hidden_size, - 2 * self.hidden_size_kv, - self.weight_attr, - 
self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - # Attention. - self.core_attention = DotProductAttention( - self.num_attention_heads, - self.hidden_size_per_attention_head, - self.num_gqa_groups, - attention_dropout, - attn_mask_type=attn_mask_type, - attention_type=self.attention_type, - tp_size=self.tp_size, - backend=self.backend, - ) - - # Linear - self.proj = Linear( - hidden_size, - hidden_size, - self.weight_attr, - self.bias_attr, - parallel_mode="row" if set_parallel_mode else None, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - def forward( - self, - hidden_states: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - encoder_output: Optional[paddle.Tensor] = None, - rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - recompute_core_attention: bool = False, - is_first_microbatch: Optional[bool] = None, - ) -> Tuple[Union[paddle.Tensor, None], ...]: - """ - MultiHeadAttention Layer. - - Parameters - ---------- - hidden_states : paddle.Tensor - Input tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. - encoder_output : Optional[paddle.Tensor], default = `None` - Output of the encoder layer. - rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. By default no input embedding is applied. 
- core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to use the fast path to set output tensors to 0 or not. - recompute_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - - if self.attn_mask_type != "causal" and attention_mask is not None: - assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - - input_dim = len(hidden_states.shape) - if input_dim == 2: - # hidden_states: [b * s_q, hidden_size] - # need to get max_seq_len from attention_mask - assert self.max_sequence_length is not None, "max_sequence_length must be provided" - max_seq_len = self.max_sequence_length - elif input_dim == 3: - # hidden_states: [b, s_q, hidden_size] - max_seq_len = hidden_states.shape[1] - else: - raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.") - - layernorm_output = None - if self.attention_type == "self": - if self.input_layernorm: - layernorm_qkv_outputs = self.layernorm_qkv( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - if self.return_layernorm_output: - mixed_qkv_layer, layernorm_output = 
layernorm_qkv_outputs - else: - mixed_qkv_layer = layernorm_qkv_outputs - else: - mixed_qkv_layer = self.qkv( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - - num_queries_per_key_value = ( - self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition - ) - - # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d] - mixed_qkv_layer = mixed_qkv_layer.reshape( - shape=[ - -1, - max_seq_len, - (num_queries_per_key_value + 2), - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - # [b, s_q, (h/ng+2), ng, d] - # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d] - query_layer, key_layer, value_layer = paddle.split( - mixed_qkv_layer, - num_or_sections=(num_queries_per_key_value, 1, 1), - axis=2, - ) - - # query: -> [b, s, h, d] - # key, value: -> [b, s, ng, d] - query_layer, key_layer, value_layer = ( - x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) - for x in (query_layer, key_layer, value_layer) - ) - - else: # cross attention - mixed_kv_layer = self.key_value( - encoder_output, - is_first_microbatch=is_first_microbatch, - ) - # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] - mixed_kv_layer = mixed_kv_layer.reshape( - shape=[ - 0, - 0, - 2 * self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - # [b, s_kv, 2 * ng, head_size] - # --> 2 [b, s_kv, ng, head_size] - key_layer, value_layer = paddle.split( - mixed_kv_layer, - num_or_sections=2, - axis=2, - ) - - if self.input_layernorm: - layernorm_query_outputs = self.layernorm_query( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - if self.return_layernorm_output: - query_layer, layernorm_output = layernorm_query_outputs - else: - query_layer = layernorm_query_outputs - else: - query_layer = self.query_layer( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - - # [b, s, hidden_size] --> [b, s, h, d] - 
query_layer = query_layer.reshape( - shape=[ - -1, - max_seq_len, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - if fused_rotary_position_embedding is None: - query_layer, key_layer = apply_rotary_pos_emb( - query_layer, key_layer, q_pos_emb, k_pos_emb - ) - else: - query_layer, key_layer, _ = fused_rotary_position_embedding( - query_layer, - key_layer, - v=None, - sin=k_pos_emb, - cos=q_pos_emb, - position_ids=None, - use_neox_rotary_style=False, - ) - - with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name): - if recompute_core_attention: - context_layer = recompute( - self.core_attention, - query_layer, - key_layer, - value_layer, - attention_mask, - core_attention_bias_type, - core_attention_bias, - set_zero, - use_reentrant=False, - ) - else: - context_layer = self.core_attention( - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) - - if input_dim == 3: - context_layer = paddle.reshape( - context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]] - ) - else: # input_dim == 2 - context_layer = paddle.reshape( - context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]] - ) - - # Output. [b, s, hidden] - attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch) - - if self.input_layernorm and self.return_layernorm_output: - return attention_output, layernorm_output - return attention_output diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py deleted file mode 100644 index a854bb70db..0000000000 --- a/transformer_engine/paddle/layer/base.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# -# See LICENSE for license information. -"""Base modules and utilities for TransformerEngine Paddle API""" - -from abc import ABC, abstractmethod -from contextlib import contextmanager -import os -import pickle -from typing import Generator, Dict, Tuple, Union, Any, List, Optional - -import numpy as np - -import paddle - -try: - from paddle.base import core - from paddle.base.framework import _dygraph_tracer -except ImportError: - from paddle.fluid import core - from paddle.fluid.framework import _dygraph_tracer - -from ..constants import FP8FwdTensors, FP8BwdTensors, dist_group_type -from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose -from ..fp8 import ( - FP8State, - FP8TensorMeta, - amax_and_scale_update, - get_global_fp8_state, - get_fp8_te_dtype, -) -from ..distributed import allgather, register_pp_fwd_begin_hook, is_pp_enabled -from ..profile import nvtx_range -from ..recompute import is_in_recompute_phase -from ..fp8_buffer import FP8RecomputeBuffer - -_2X_ACC_FPROP = False -_2X_ACC_DGRAD = True -_2X_ACC_WGRAD = True -_cublas_workspace = None - - -def get_cublas_workspace_size_bytes() -> None: - """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if paddle.device.cuda.get_device_capability()[0] >= 9: - return 33_554_432 - return 4_194_304 - - -def get_workspace() -> paddle.Tensor: - """Returns workspace for cublas.""" - global _cublas_workspace - if _cublas_workspace is None: - _cublas_workspace = paddle.empty( - [get_cublas_workspace_size_bytes()], - dtype="uint8", - ) - return _cublas_workspace - - -class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): - """Base TE Layer.""" - - def __init__(self) -> None: - super().__init__() - assert "gpu" in paddle.device.get_device(), "TransformerEngine needs CUDA." 
- self.fp8_initialized = False - self.fp8_enabled = False - self.fp8_calibration = False - self.fp8_meta = {} - self.fp8_meta["fp8_checkpoint"] = False - self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe() - self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True) - self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) - self.tp_group = None - self.tp_size = 1 - self.sequence_parallel = False - self.fp8_meta["autocast_id_fwd_stack"] = [] - self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) - ) - # weights that stored in fp16 would be cast into fp8 every first microstep - self.fp8_weights = [] - self.fp8_weight_cache = {} - self.registered_pp_start_callback = False - self.current_step_id = None - - def set_activation_dtype(self, inp: paddle.Tensor) -> None: - """Get activation data type for AMP.""" - tracer = _dygraph_tracer() - if tracer and tracer._amp_level != core.AmpLevel.O0: - # Set activation_dtype to the Paddle AMP dtype if under 'paddle.amp.auto_cast' context - if tracer._amp_dtype == "float32": - self.activation_dtype = paddle.float32 - elif tracer._amp_dtype == "bfloat16": - self.activation_dtype = paddle.bfloat16 - elif tracer._amp_dtype == "float16": - self.activation_dtype = paddle.float16 - else: - raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.") - else: - # If not under paddle.amp.auto_cast, set activation_dtype to the input dtype. - # Also, make sure the parameters match the input dtype. - - # Skip the check if activation_dtype is already set and if activation_dtype - # matches input dtype. If they do not match, e.g, when user switch from AMP - # training to normal training, activation_dtype will still be updated. 
- if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype: - return - - dtype = inp.dtype - - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - - self.activation_dtype = dtype - - # This routine is shared across FP8 and FP8_calibration paths so should not actually - # assume FP8 execution. - def fp8_init(self, num_gemms: int = 1) -> None: - """Initialize fp8 related metadata and tensors during fprop.""" - global_fp8_state = get_global_fp8_state() - self.fp8_enabled = global_fp8_state.is_fp8_enabled() - self.fp8_calibration = global_fp8_state.is_fp8_calibration() - self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration - - if self.fp8_enabled or self.fp8_calibration: - # FP8 init has already been run and recipe is the same, don't do anything. - if ( - self.fp8_initialized - and global_fp8_state.get_fp8_recipe() == self.fp8_meta["recipe"] - ): - return - - # Set FP8, recipe, and other FP8 metadata - self.fp8_meta["recipe"] = global_fp8_state.get_fp8_recipe() - self.fp8_meta["fp8_group"] = global_fp8_state.get_fp8_group() - - # Set FP8_MAX per tensor according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd - - # Allocate scales and amaxes - amax_history_len = self.fp8_meta["recipe"].amax_history_len - self.fp8_meta["scaling_fwd"].prepare(num_gemms, amax_history_len) - self.fp8_meta["scaling_bwd"].prepare(num_gemms, amax_history_len) - self.fp8_initialized = True - else: - # If fp8 isn't enabled, turn off and return. 
- self.fp8_initialized = False - return - - def set_fp8_weights(self) -> None: - """Initializes FP8 weights for the module""" - if not self.fp8_enabled: - return - - for i, weight in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - if ( - weight_cast_key in self.fp8_weight_cache - and self.fp8_weight_cache[weight_cast_key].shape == weight.shape - ): - return - - self.fp8_weight_cache[weight_cast_key] = paddle.empty( - shape=weight.shape, - dtype=paddle.uint8, - ) - - self.fp8_weight_cache[weight_transpose_key] = paddle.empty( - shape=[weight.shape[1], weight.shape[0]], - dtype=paddle.uint8, - ) - - def _get_fp8_state(self) -> paddle.Tensor: - """Dump FP8 state to paddle.Tensor.""" - state = None - if self.fp8_meta["fp8_checkpoint"]: - state = {} - state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy() - state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy() - state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy() - state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy() - # Store other pickelable values. 
- extra = {} - for k, v in self.fp8_meta.items(): - if isinstance(v, (bool, int, float, str)): - extra[k] = v - state["extra_fp8_variables"] = extra - - state_serialized = pickle.dumps(state) - state_tensor = paddle.to_tensor(np.frombuffer(state_serialized, dtype=np.uint8)) - - return state_tensor - - @paddle.no_grad() - def state_dict( - self, - destination=None, - include_sublayers=True, - structured_name_prefix="", - use_hook=True, - ): - """Save FP8 State when checkpointing.""" - st = super().state_dict( - destination=destination, - include_sublayers=include_sublayers, - structured_name_prefix=structured_name_prefix, - use_hook=use_hook, - ) - st["fp8_state"] = self._get_fp8_state() - return st - - def _set_fp8_state(self, state: paddle.Tensor) -> None: - """Load previous state.""" - if state is None: - return - - state = pickle.loads(state.numpy().tobytes()) - if state is None: - return - - # Load fp8 meta tensors. - self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"]) - self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"]) - - # Restore global FP8 buffer states. - global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() - global_fp8_bwd_buffer = get_global_fp8_state().get_fp8_bwd_buffer() - global_fp8_fwd_buffer.from_numpy(state["global_fp8_fwd_buffer"]) - global_fp8_bwd_buffer.from_numpy(state["global_fp8_bwd_buffer"]) - - # Load extra items. 
- self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[ - 0 - ] - recompute_buffer_pos_key = FP8RecomputeBuffer.get_buffer_position_key() - if recompute_buffer_pos_key in self.fp8_meta: - del self.fp8_meta[recompute_buffer_pos_key] - - @paddle.no_grad() - def set_state_dict(self, state_dict, use_structured_name=True): - """Restore FP8 State from checkpoint.""" - fp8_state_tensor = state_dict.pop("fp8_state") - self._set_fp8_state(fp8_state_tensor) - - return super().set_state_dict(state_dict) - - @contextmanager - def prepare_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Union[bool, None], - num_gemms: int = 1, - ) -> Generator[paddle.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - - if self.fp8_enabled and is_in_recompute_phase(): - global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() - global_recompute_buffer.retrieve_fp8_meta_tensors(self.fp8_meta) - else: - self.set_activation_dtype(inp) - self.fp8_init(num_gemms=num_gemms) - - # Create persistent tensors for fp8 weights and their transposes - # only when fp8 weight caching is used. - if is_first_microbatch is not None: - self.set_fp8_weights() - - if self.fp8_enabled and self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." 
- ) - - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch - - # Previous iteration was grad_enabled - if self.fp8_meta.get("update_amax_and_scale_fwd", False): - global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() - global_fp8_fwd_buffer.wait() - # Register PP forward begin hook when CUDAGraph is enabled. - # NOTE(tizheng): register_pp_fwd_begin_hook prevents layer parameters from being freed - # when the layer object is deleted. Need to find a better way. - if get_global_fp8_state().is_cudagraph_enabled() and self.current_step_id is None: - self.current_step_id = paddle.to_tensor( - [1], dtype=paddle.int32, place=paddle.CPUPlace() - ) - - def current_step_id_callback( - step_id=None, **kwargs - ): # pylint: disable=unused-argument - self.current_step_id.copy_( - paddle.to_tensor( - [step_id], dtype=paddle.int32, place=paddle.CPUPlace() - ), - True, - ) - - if is_pp_enabled(): - register_pp_fwd_begin_hook(current_step_id_callback) - - if self.fp8_meta["recipe"].reduce_amax: - global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) - amax_and_scale_update( - self.fp8_meta, - fwd_update=True, - update_weight_scale_inv=update_weight_scale_inv, - current_step_id_tensor=self.current_step_id, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) - else: - amax_and_scale_update( - self.fp8_meta, - fwd_update=True, - update_weight_scale_inv=update_weight_scale_inv, - current_step_id_tensor=self.current_step_id, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - - if self.fp8_enabled and self.training: - # Setup for amax reduction - if self.fp8_meta["recipe"].reduce_amax: - global_fp8_state = get_global_fp8_state() - self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module() - self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id() - self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"]) - 
self.fp8_meta["update_amax_and_scale_fwd"] = True - else: - self.fp8_meta["update_amax_and_scale_fwd"] = False - - # Activation recomputation is used and this is the first forward phase. - if ( - self.fp8_enabled - and self.training - and get_global_fp8_state().is_fp8_recompute_enabled() - ): - global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() - global_recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta) - - with nvtx_range(self.__class__.__name__ + " forward"): - yield inp - - if self.fp8_enabled and is_in_recompute_phase(): - FP8RecomputeBuffer.restore_fp8_meta_tensors(self.fp8_meta) - return - - if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax: - global_fp8_state = get_global_fp8_state() - global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer() - global_fp8_fwd_buffer.add_amax(self.fp8_meta) - global_fp8_fwd_buffer.set_for_amax_reduction( - self.fp8_meta, - self.tp_group, - self.tp_size, - ) - - @staticmethod - @contextmanager - def prepare_backward( - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - name: str = "", - ) -> Generator[None, None, None]: - """Checks and prep for BWD.""" - if fp8_enabled: - global_fp8_state = get_global_fp8_state() - global_fp8_bwd_buffer = global_fp8_state.get_fp8_bwd_buffer() - global_fp8_bwd_buffer.wait() - - if fp8_meta["recipe"].reduce_amax: - global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta) - amax_and_scale_update( - fp8_meta, - fwd_update=False, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - global_fp8_bwd_buffer.set_for_deletion(fp8_meta) - - # Get new backward key. 
- fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) - else: - amax_and_scale_update( - fp8_meta, - fwd_update=False, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - - with nvtx_range(name + " backward"): - yield - - if fp8_enabled and fp8_meta["recipe"].reduce_amax: - global_fp8_bwd_buffer.add_amax(fp8_meta) - if fp8_meta["first_module"]: - global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) - - @staticmethod - def grad_output_preprocess( - ctx, grad_output: paddle.Tensor, row_parallel_mode: bool - ) -> Tuple[Union[paddle.Tensor, None], ...]: - """Utility function for backward. - Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output` in higher precision. - R2: gathered `grad_output` in FP8. - R3: R2 transposed. - R4: bias gradient on R1. - """ - grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1])) - gather_grad_output = row_parallel_mode and ctx.sequence_parallel - - # No-FP8 case: bgrad is fused with wgrad for this case. 
- if not ctx.fp8_enabled: - if gather_grad_output: - grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group) - return grad_output_mat, None, None, None - - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - if gather_grad_output: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather - if ctx.use_bias: - bgrad = grad_output_mat.sum(axis=0) - else: - bgrad = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - grad_output_c, _ = allgather(grad_output_c, ctx.tp_group) - grad_output_t = transpose(grad_output_c, fp8_dtype_backward) - - return grad_output_mat, grad_output_c, grad_output_t, bgrad - - # FP8 case with gather and non-FP8 wgrad - grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group) - - # FP8 case without gather: cast, transpose, bgrad fused - if ctx.use_bias: - bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - grad_output_c, grad_output_t = cast_transpose( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - grad_output_t = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - bgrad = None - return grad_output_mat, grad_output_c, grad_output_t, bgrad - - @abstractmethod - def forward(self): - """Needs override.""" - - def get_fp8_weights_scratchpad_and_cast( - self, - is_first_microbatch: Union[bool, None], - ) -> List[Optional[paddle.Tensor]]: - """ - Fetch the fp8 weight tensor placeholders if they exist (when - `is_first_microbatch` is not `None`) - """ - if not self.fp8_enabled or 
is_first_microbatch is None: - return [None, None] * len(self.fp8_weights) - - out_list = [] - for i, _ in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - assert ( - weight_cast_key in self.fp8_weight_cache - ), "TE internal error: fp8 weight buffer is not found" - - weight_fp8 = self.fp8_weight_cache[weight_cast_key] - weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key] - - # Disable fp8 weight cache - # is_first_microbatch is None -> we cast the weights into fp8 every micro step - # Enalbe fp8 weight cache - # is_first_microbatch == true -> we cast the weights into fp8 every micro step - - out_list.extend([weight_fp8, weight_t_fp8]) - - # is cudagraph is enabled we cast the weight before the pp pipe - # we only register the callback once - if get_global_fp8_state().is_cudagraph_enabled() and ( - not self.registered_pp_start_callback and is_pp_enabled() - ): - - fp8_dtype_forward = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) - - def cast_callback(step_id=None, **kwargs): # pylint: disable=unused-argument - update_fp8_weights = step_id == 0 - - for i, weight in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - assert ( - weight_cast_key in self.fp8_weight_cache - ), "TE internal error: fp8 weight buffer is not found" - - weight_fp8 = self.fp8_weight_cache[weight_cast_key] - weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key] - - if paddle.is_grad_enabled(): - if update_fp8_weights: - cast_transpose( - weight, - self.fp8_meta["scaling_fwd"], - ( - FP8FwdTensors.GEMM1_WEIGHT - if i == 1 - else FP8FwdTensors.GEMM2_WEIGHT - ), - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) - else: - if update_fp8_weights: - cast_to_fp8( - weight, - self.fp8_meta["scaling_fwd"], - ( - FP8FwdTensors.GEMM1_WEIGHT - if i == 1 - else FP8FwdTensors.GEMM2_WEIGHT - ), - 
fp8_dtype_forward, - out=weight_fp8, - ) - - cast_callback(0 if is_first_microbatch else 1) - register_pp_fwd_begin_hook(cast_callback) - self.registered_pp_start_callback = True - return out_list diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py deleted file mode 100644 index be12b6534f..0000000000 --- a/transformer_engine/paddle/layer/layernorm.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Linear API""" - -import os -from typing import Union, Tuple - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from ..constants import TE_DType -from ..cpp_extensions import layernorm_fwd, layernorm_bwd -from ..distributed import mark_as_sequence_parallel_parameter - -__all__ = ["LayerNorm"] - - -class _LayerNorm(paddle.autograd.PyLayer): - """TE Non-FP8 LayerNorm""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: paddle.Tensor, - eps: float, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "LayerNorm not possible" - inputmat = inp.reshape((-1, in_features)) - - ln_out, mu, rsigma = layernorm_fwd( - inputmat, - ln_weight, - ln_bias, - eps, - TE_DType[inp.dtype], - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.requires_dx = not inp.stop_gradient - ctx.requires_dw = not ln_weight.stop_gradient - ctx.requires_dbias = not ln_bias.stop_gradient - return ln_out.reshape(inp.shape) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> 
Tuple[Union[paddle.Tensor, None], ...]: - inputmat, ln_weight, mu, rsigma = ctx.saved_tensor() - d_ln_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma, dbeta = layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, - dgamma if ctx.requires_dw else None, - dbeta if ctx.requires_dbias else None, - ) - - -class LayerNorm(paddle.nn.Layer): - r""" - Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization `__ - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta - - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of - size :attr:`hidden_size` - - Parameters - ---------- - hidden_size : int - size of each input sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for softmax operation. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. 
- """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - zero_centered_gamma: bool = False, - sequence_parallel: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.sequence_parallel = sequence_parallel - self.backend = backend - self._dtype = self._helper.get_default_dtype() - - self._weight_attr = weight_attr - if not self._weight_attr: - self._weight_attr = paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ) - - self._bias_attr = bias_attr - if self._bias_attr is False: - self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0), trainable=False) - - self.weight = self.create_parameter( - shape=[hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - - self.bias = self.create_parameter( - shape=[hidden_size], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - mark_as_sequence_parallel_parameter(self.bias) - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. 
- self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: - """LayerNorm FWD""" - return _LayerNorm.apply( - inp, - self.weight, - self.bias, - self.eps, - self.fwd_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - ) - - def _pd_forward( - self, - inp: paddle.Tensor, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - return F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.weight, - bias=self.bias, - epsilon=self.eps, - ) - - def forward(self, *args, **kwargs): - """forward""" - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py deleted file mode 100644 index 57c91238e6..0000000000 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ /dev/null @@ -1,721 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""LayerNormLinear API""" - -import warnings -import os -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from ..cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, - layernorm_fwd, - layernorm_fwd_fp8, - layernorm_bwd, - rmsnorm_fwd_fp8, - rmsnorm_fwd, - rmsnorm_bwd, -) - -from .base import TransformerEngineBaseLayer -from .linear import _linear_fwd, _linear_bwd -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type -from ..distributed import ( - allreduce, - get_tp_group_and_world_size, - identity, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - save_for_backward_allow_none, - saved_tensor_allow_none, -) - -__all__ = ["LayerNormLinear"] - - -def _apply_normalization_fwd( - normalization: str, - inputmat: paddle.Tensor, - norm_weight: paddle.Tensor, - norm_bias: Union[paddle.Tensor, None], - out_fp8_index: FP8FwdTensors, - eps: float, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_norm_output: bool, - fwd_norm_sm_margin: int, - zero_centered_gamma: bool, -): - """Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path""" - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - if normalization == "RMSNorm": - assert norm_bias is None, "RMSNorm does not support bias!" 
- norm_weight = cast_if_needed_inplace(norm_weight, activation_dtype) - if norm_bias is not None: - norm_bias = cast_if_needed_inplace(norm_bias, activation_dtype) - - norm_kwargs = { - "inp": inputmat, - "weight": norm_weight, - "eps": eps, - "otype": TE_DType[activation_dtype], - "sm_margin": fwd_norm_sm_margin, - "zero_centered_gamma": zero_centered_gamma, - } - - fwd_normalization_funcs = { - ("LayerNorm", True, True): layernorm_fwd, - ("LayerNorm", True, False): layernorm_fwd_fp8, - ("LayerNorm", False, True): layernorm_fwd, - ("LayerNorm", False, False): layernorm_fwd, - ("RMSNorm", True, True): rmsnorm_fwd, - ("RMSNorm", True, False): rmsnorm_fwd_fp8, - ("RMSNorm", False, True): rmsnorm_fwd, - ("RMSNorm", False, False): rmsnorm_fwd, - } - - if normalization == "LayerNorm": - norm_kwargs["bias"] = norm_bias - norm_fwd_func = fwd_normalization_funcs[(normalization, fp8_enabled, return_norm_output)] - - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if not return_norm_output: - fp8_kwargs = { - "fp8_meta_tensor": fp8_meta["scaling_fwd"], - "fp8_tensor": out_fp8_index, - "otype": fp8_dtype_forward, - } - norm_kwargs.update(fp8_kwargs) - - out_tuple = norm_fwd_func(**norm_kwargs) - - if normalization == "LayerNorm": - norm_out_return, mu, rsigma = out_tuple - else: # RMSNorm - norm_out_return, rsigma = out_tuple - mu = None - - if fp8_enabled and return_norm_output: - norm_out = cast_to_fp8( - norm_out_return, - fp8_meta["scaling_fwd"], - out_fp8_index, - fp8_dtype_forward, - ) - else: - norm_out = norm_out_return - - return ( - norm_out_return, - norm_out, - mu, - rsigma, - ) - - -def _apply_normalization_bwd( - normalization: str, - inputmat: paddle.Tensor, - dgrad: paddle.Tensor, - norm_weight: paddle.Tensor, - mu: Union[paddle.Tensor, None], - rsigma: paddle.Tensor, - grad_norm_out_return: paddle.Tensor, - return_norm_output: bool, - bwd_norm_sm_margin: int, - zero_centered_gamma: bool, -): - assert 
normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - if normalization == "RMSNorm": - assert mu is None, "RMSNorm does not support bias!" - # LayerNorm gradient - d_norm_out = dgrad.reshape(inputmat.shape) - # Residual gradient - if return_norm_output: - d_norm_out = d_norm_out + grad_norm_out_return.reshape(d_norm_out.shape) - - norm_bwd_func = layernorm_bwd if normalization == "LayerNorm" else rmsnorm_bwd - norm_bwd_kwargs = { - "dz": d_norm_out, - "x": inputmat, - "rsigma": rsigma, - "gamma": norm_weight, - "sm_margin": bwd_norm_sm_margin, - "zero_centered_gamma": zero_centered_gamma, - } - if normalization == "LayerNorm": - norm_bwd_kwargs["mu"] = mu - - out_tuple = norm_bwd_func(**norm_bwd_kwargs) - if normalization == "LayerNorm": - dxmat, dgamma, dbeta = out_tuple - else: # RMSNorm - dxmat, dgamma = out_tuple - dbeta = None - - return dxmat, dgamma, dbeta - - -class _LayerNormLinear(paddle.autograd.PyLayer): - """TE implementation of LayerNormLinear""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: Union[paddle.Tensor, None], - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - bias: Union[paddle.Tensor, None], - use_bias: bool, - eps: float, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_layernorm_output: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: - if normalization == "RMSNorm": - assert ln_bias is None, "RMSNorm does not support bias!" 
- else: # LayerNorm - assert ln_bias is not None, "LayerNorm requires bias!" - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(weight) - - # LayerNorm Fwd + FP8 Cast - ( - ln_out_return, - ln_out, - mu, - rsigma, - ) = _apply_normalization_fwd( - normalization, - inputmat, - ln_weight, - ln_bias, - FP8FwdTensors.GEMM1_INPUT, - eps, - fp8_enabled, - fp8_meta, - activation_dtype, - return_layernorm_output, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - # Linear Fwd - out, weight_t_fp8 = _linear_fwd( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_fp8, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - bias, - use_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - - if is_grad_enabled: - save_for_backward_allow_none( - ctx, - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8 if fp8_enabled else None, - ln_out, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - - ctx.activation_dtype = activation_dtype - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - ctx.return_layernorm_output = return_layernorm_output - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.parallel_mode = parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_wgrad = not weight.stop_gradient - ctx.requires_bgrad = use_bias and not bias.stop_gradient - ctx.requires_ln_bgrad = ln_bias 
is not None and not ln_bias.stop_gradient - ctx.requires_ln_wgrad = not ln_weight.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.has_ln_bias = ln_bias is not None - ctx.normalization = normalization - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) - - if return_layernorm_output: - return out, ln_out_return.reshape(inp.shape) - return out - - @staticmethod - def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, ...] - ) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" - ): - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8, - ln_out, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ( - grad_output, - grad_output_c, - grad_output_t, - bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess( - ctx, grad_outputs[0], ctx.parallel_mode == "row" - ) - - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - # Prepare ln_out for Linear bwd - linear_inputmat = ln_out - if ctx.fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - if ctx.requires_wgrad and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - linear_inputmat = cast_from_fp8( - ln_out, - ctx.fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - - # Linear Bwd - dgrad, wgrad, bgrad_ = _linear_bwd( - linear_inputmat, - None, # inputmat_t will be automatically computed if not provided - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - grad_output, - grad_output_c, - 
grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.requires_bgrad, - ctx.fp8_enabled, - ctx.fp8_meta, - True, # Always compute dgrad to feed into LayerNorm bwd - ctx.requires_wgrad, - ctx.activation_dtype, - ctx.parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - if not ctx.fp8_enabled: - # bgrad is fused with gemm for non-FP8 path - bgrad = bgrad_ - - # LayerNorm Bwd - dxmat, dgamma, dbeta = _apply_normalization_bwd( - ctx.normalization, - inputmat, - dgrad, - ln_weight, - mu, - rsigma, - grad_outputs[1] if ctx.return_layernorm_output else None, - ctx.return_layernorm_output, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - - bgrad = bgrad if ctx.requires_bgrad else None - bgrad_out = (bgrad,) if ctx.use_bias else () - dbeta = dbeta if ctx.requires_ln_bgrad else None - dbeta_out = (dbeta,) if ctx.has_ln_bias else () - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - weight_cache_grad = (None, None) - - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - wgrad = None - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma if ctx.requires_ln_wgrad else None, - *dbeta_out, - wgrad if ctx.requires_wgrad else None, - *weight_cache_grad, - *bgrad_out, - ) - - -class LayerNormLinear(TransformerEngineBaseLayer): - r""" - Applies layer normalization followed by linear transformation to the incoming data. - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. 
- bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module is - taken post layernorm. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = "LayerNorm", - return_layernorm_output: bool = False, - zero_centered_gamma: bool = False, - parallel_mode: Optional[str] = None, - sequence_parallel: bool = False, - tp_group: Union[dist_group_type, None] = None, - fuse_wgrad_accumulation: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.eps = eps - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - self.return_layernorm_output = return_layernorm_output - self.zero_centered_gamma = zero_centered_gamma - self.backend = backend - - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=parallel_mode is not None - ) - self.tensor_parallel = self.tp_size > 1 - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - # LayerNorm weights - self.ln_weight = self.create_parameter( - shape=[self.in_features], - attr=paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ), - dtype=self._dtype, - is_bias=False, - ) - if self.normalization != "RMSNorm": - self.ln_bias = self.create_parameter( - 
shape=[self.in_features], - attr=paddle.ParamAttr(initializer=Constant(value=0.0)), - dtype=self._dtype, - is_bias=True, - ) - else: - self.ln_bias = None - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.ln_weight) - if self.ln_bias is not None: - mark_as_sequence_parallel_parameter(self.ln_bias) - - # Initialize Linear weight parameter - with track_rng_state(enable=self.tensor_parallel): - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=( - [self.out_features, self.in_features] - if self.backend == "transformer_engine" - else [self.in_features, self.out_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.weight, self.tensor_parallel, self.parallel_mode, self.backend - ) - self.fp8_weights.append(self.weight) - - # Initialize Linear bias parameter - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if self.has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=( - self._bias_attr - if not use_default_bias - else paddle.ParamAttr(initializer=Constant(value=0.0)) - ), - dtype=self._dtype, - is_bias=True, - ) - if parallel_mode == "column": - set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) - if parallel_mode == "row" and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.bias) - else: - self.bias = None - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. 
This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - """ - - with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. - weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - - out = _LayerNormLinear.apply( - inp, - self.ln_weight, - self.ln_bias, - self.weight, - weight_fp8, - weight_t_fp8, - self.bias if self.gemm_bias_fused_add else None, - self.has_bias and self.gemm_bias_fused_add, - self.eps, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - self.return_layernorm_output, - paddle.is_grad_enabled(), - self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - ) - - if self.return_layernorm_output: - out, ln_out = out - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) - - if self.return_layernorm_output: - return out, ln_out - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - 
is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - - if self.normalization == "RMSNorm": - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps, - ) - - if self.parallel_mode == "column" and self.tensor_parallel: - norm_out = identity(norm_out, self.tp_group) - out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == "row" and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.bias if self.bias is not None else out - if self.return_layernorm_output: - return out, norm_out - return out - - def forward(self, *args, **kwargs): - """ - Apply layer normalization to the input followed by a linear transformation. - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. 
- When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py deleted file mode 100644 index 069fb82c69..0000000000 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ /dev/null @@ -1,1010 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""LayerNormMLP API""" - -import os -import warnings -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from .base import TransformerEngineBaseLayer -from .layernorm_linear import _apply_normalization_fwd, _apply_normalization_bwd -from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type -from ..cpp_extensions import ( - cast_from_fp8, - gelu_fp8, - swiglu_fp8, - swiglu, - dswiglu, - cast_transpose_bgrad, - dgelu_cast_transpose_bgrad_fp8, -) -from ..distributed import ( - allreduce, - get_tp_group_and_world_size, - identity, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - get_paddle_act_func, - save_for_backward_allow_none, - saved_tensor_allow_none, -) - -__all__ = ["LayerNormMLP"] - - -def _mlp_forward( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - fc1_weight: 
paddle.Tensor, - fc1_weight_fp8: Optional[paddle.Tensor], - fc1_weight_t_fp8: Optional[paddle.Tensor], - fc1_weight_fp8_index: FP8FwdTensors, - fc1_bias: Union[paddle.Tensor, None], - use_fc1_bias: bool, - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT - fc2_weight: paddle.Tensor, - fc2_weight_fp8: Optional[paddle.Tensor], - fc2_weight_t_fp8: Optional[paddle.Tensor], - fc2_weight_fp8_index: FP8FwdTensors, - fc2_bias: Union[paddle.Tensor, None], - use_fc2_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - activation: str, - is_grad_enabled: bool, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_first_microbatch: bool, -): - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fc1_out, fc1_weight_t_fp8 = _linear_fwd_fp8( - inputmat, - inputmat_fp8_index, - fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - fc1_weight_fp8_index, - fc1_bias, - use_fc1_bias, - fp8_meta, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - if activation == "gelu": - gelu_out = gelu_fp8( - fc1_out, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - ) - elif activation == "swiglu": - gelu_out = swiglu_fp8( - fc1_out, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - ) - else: - raise NotImplementedError("Activation type " + activation + " is not supported!") - - fc2_out, fc2_weight_t_fp8 = _linear_fwd_fp8( - gelu_out, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - fc2_weight_fp8_index, - fc2_bias, - use_fc2_bias, - fp8_meta, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - else: - fc1_outputs = 
_linear_fwd_non_fp8( - inputmat, - inputmat_fp8_index, - fc1_weight, - fc1_weight_fp8_index, - fc1_bias, - use_fc1_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - activation=activation, - ) - - if activation == "gelu": - fc1_out, gelu_out = fc1_outputs - elif activation == "swiglu": - fc1_out = fc1_outputs - gelu_out = swiglu(fc1_out, TE_DType[activation_dtype]) - else: - raise NotImplementedError("Activation type " + activation + " is not supported!") - - fc2_out = _linear_fwd_non_fp8( - gelu_out, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_fp8_index, - fc2_bias, - use_fc2_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - ) - return ( - fc1_out, - gelu_out, - fc2_out, - fc1_weight_t_fp8 if fp8_enabled else None, - fc2_weight_t_fp8 if fp8_enabled else None, - ) - - -def _mlp_backward( - fc1_input: paddle.Tensor, # ln_out, BF16 / FP8 - fc1_input_fp8_index: FP8FwdTensors, - fc1_weight: paddle.Tensor, - fc1_weight_t_fp8: paddle.Tensor, - fc1_weight_fp8_index: FP8FwdTensors, - fc1_grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT2 - requires_fc1_wgrad: bool, - requires_fc1_bgrad: bool, - fc1_out: paddle.Tensor, - fc2_input: paddle.Tensor, # gelu_out - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT - fc2_weight: paddle.Tensor, - fc2_weight_t_fp8: paddle.Tensor, - fc2_weight_fp8_index: FP8FwdTensors, - requires_fc2_wgrad: bool, - requires_fc2_bgrad: bool, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT1 - fwd_scale_inverses: paddle.Tensor, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - activation_dtype: paddle.dtype, - activation: str, - set_parallel_mode: bool, - tensor_parallel: bool, - 
sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad, - ) = ( - None, - None, - None, - None, - None, - ) - - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - # FC2 Bwd - fp8_wgrad = not fp8_meta["recipe"].override_linear_precision.wgrad - if requires_fc2_wgrad and not fp8_wgrad: - fc2_input = cast_from_fp8( - fc2_input, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - TE_DType[activation_dtype], - ) - - fc2_dgrad, fc2_wgrad = _linear_bwd_fp8( - fc2_input, - None, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_t_fp8, - fc2_weight_fp8_index, - grad_output, - grad_output_c, - grad_output_t, - grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - True, - requires_fc2_wgrad, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - dgelu_t = None - fc1_bgrad_ = None - if activation == "gelu": - # GELU Bwd - dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8( - fc2_dgrad, - fc1_out, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - ) - elif activation == "swiglu": - dgelu = dswiglu(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - fc1_bgrad_, dgelu, dgelu_t = cast_transpose_bgrad( - dgelu, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - ) - - if requires_fc1_bgrad: - fc1_bgrad = fc1_bgrad_ - - # FC1 Bwd - dgelu_no_fp8 = None - if requires_fc1_wgrad and not fp8_wgrad: - # TODO(tizheng) Paddle lacks fused dgelu_bgrad OP. Cast from dgrad(fp8) instead. 
- dgelu_no_fp8 = cast_from_fp8( - dgelu, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - TE_DType[activation_dtype], - ) - fc1_input = cast_from_fp8( - fc1_input, - fp8_meta["scaling_fwd"], - fc1_input_fp8_index, - fp8_dtype_forward, - TE_DType[activation_dtype], - ) - - fc1_dgrad, fc1_wgrad = _linear_bwd_fp8( - fc1_input, - None, - fc1_input_fp8_index, - fc1_weight, - fc1_weight_t_fp8, - fc1_weight_fp8_index, - dgelu_no_fp8, - dgelu, - dgelu_t, - fc1_grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - requires_dgrad, - requires_fc1_wgrad, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - else: - dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8( - fc2_input, - fc2_weight, - grad_output, - requires_fc2_bgrad, - True, - requires_fc2_wgrad, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - gelu_input=fc1_out, - activation=activation, - ) - - if activation == "swiglu": - dgelu = dswiglu(dgelu, fc1_out, TE_DType[dgelu.dtype]) - - fc1_dgrad, fc1_wgrad, fc1_bgrad = _linear_bwd_non_fp8( - fc1_input, - fc1_weight, - dgelu, - requires_fc1_bgrad, - requires_dgrad, - requires_fc1_wgrad, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - return ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad, - ) - - -class _LayerNormMLP(paddle.autograd.PyLayer): - """TE implementation of LayerNormMLP""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: 
Union[paddle.Tensor, None], - fc1_weight: paddle.Tensor, - fc1_weight_fp8: Optional[paddle.Tensor], - fc1_weight_t_fp8: Optional[paddle.Tensor], - fc1_bias: Union[paddle.Tensor, None], - use_fc1_bias: bool, - fc2_weight: paddle.Tensor, - fc2_weight_fp8: Optional[paddle.Tensor], - fc2_weight_t_fp8: Optional[paddle.Tensor], - fc2_bias: Union[paddle.Tensor, None], - use_fc2_bias: bool, - eps: float, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_layernorm_output: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - activation: str, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: - if normalization == "RMSNorm": - assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm - assert ln_bias is not None, "LayerNorm requires bias!" 
- # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(fc1_weight) - assert_dim_for_fp8_forward_exec(fc2_weight) - - # only support gelu for now - assert activation in ["gelu", "swiglu"], "Only gelu and swiglu are supported for now" - - # LayerNorm Fwd + FP8 Cast - ( - ln_out_return, - ln_out, - mu, - rsigma, - ) = _apply_normalization_fwd( - normalization, - inputmat, - ln_weight, - ln_bias, - FP8FwdTensors.GEMM1_INPUT, - eps, - fp8_enabled, - fp8_meta, - activation_dtype, - return_layernorm_output, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - ( - fc1_out, - gelu_out, - fc2_out, - fc1_weight_t_fp8, - fc2_weight_t_fp8, - ) = _mlp_forward( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - fc1_bias, - use_fc1_bias, - FP8FwdTensors.GEMM2_INPUT, - fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - FP8FwdTensors.GEMM2_WEIGHT, - fc2_bias, - use_fc2_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - activation, - is_grad_enabled, - set_parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_first_microbatch, - ) - - if is_grad_enabled: - save_for_backward_allow_none( - ctx, - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - ctx.activation_dtype = activation_dtype - ctx.activation = activation - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_fc1_bias = use_fc1_bias - ctx.use_fc2_bias = use_fc2_bias - ctx.inp_shape = inp.shape - ctx.return_layernorm_output = return_layernorm_output - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = 
zero_centered_gamma - ctx.set_parallel_mode = set_parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient - ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient - ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient - ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient - ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient - ctx.requires_ln_wgrad = not ln_weight.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.has_ln_bias = ln_bias is not None - ctx.normalization = normalization - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1])) - - if return_layernorm_output: - return fc2_out, ln_out_return.reshape(inp.shape) - return fc2_out - - @staticmethod - def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, ...] 
- ) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" - ): - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess - ( - grad_output, - grad_output_c, - grad_output_t, - fc2_bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True) - - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad_, - ) = _mlp_backward( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - fc1_weight, - fc1_weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - FP8BwdTensors.GRAD_OUTPUT2, - ctx.requires_fc1_wgrad, - ctx.requires_fc1_bgrad, - fc1_out, - gelu_out, - FP8FwdTensors.GEMM2_INPUT, - fc2_weight, - fc2_weight_t_fp8, - FP8FwdTensors.GEMM2_WEIGHT, - ctx.requires_fc2_wgrad, - ctx.requires_fc2_bgrad, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.fp8_enabled, - ctx.fp8_meta, - True, - ctx.activation_dtype, - ctx.activation, - ctx.set_parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - if not ctx.fp8_enabled: - # fc2_bias is fused with gemm for non-FP8 path - fc2_bgrad = fc2_bgrad_ - - # LayerNorm Bwd - dxmat, dgamma, dbeta = _apply_normalization_bwd( - ctx.normalization, - inputmat, - fc1_dgrad, - ln_weight, - mu, - rsigma, - grad_outputs[1] if ctx.return_layernorm_output else None, - 
ctx.return_layernorm_output, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - - fc1_bgrad = fc1_bgrad if ctx.requires_fc1_bgrad else None - fc2_bgrad = fc2_bgrad if ctx.requires_fc2_bgrad else None - fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else () - fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else () - dbeta = dbeta if ctx.requires_ln_bgrad else None - dbeta_out = (dbeta,) if ctx.has_ln_bias else () - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - fc1_weight_cache_grad = () - fc2_weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - fc1_weight_cache_grad = (None, None) - fc2_weight_cache_grad = (None, None) - - if ctx.requires_fc1_wgrad and ctx.fuse_wgrad_accumulation: - fc1_wgrad = None - if ctx.requires_fc2_wgrad and ctx.fuse_wgrad_accumulation: - fc2_wgrad = None - - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma if ctx.requires_ln_wgrad else None, - *dbeta_out, - fc1_wgrad if ctx.requires_fc1_wgrad else None, - *fc1_weight_cache_grad, - *fc1_bgrad_out, - fc2_wgrad if ctx.requires_fc2_wgrad else None, - *fc2_weight_cache_grad, - *fc2_bgrad_out, - ) - - -class LayerNormMLP(TransformerEngineBaseLayer): - r""" - Applies layer normalization on the input followed by the MLP module, consisting of - 2 successive linear transformations, separated by the GeLU activation. - - Parameters - ---------- - hidden_size : int - size of each input sample. - ffn_hidden_size : int - intermediate size to which input samples are projected. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. 
- activation : str, default = 'gelu' - activation function used. - Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module - is taken post layernorm. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row - Parallel as described `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : paddle.distributed.collective.Group, default = `None` - tensor parallel process group. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. 
- """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = "LayerNorm", - activation: str = "gelu", - return_layernorm_output: bool = False, - zero_centered_gamma: bool = False, - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.eps = eps - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Normalization type not supported" - self.activation = activation - self.return_layernorm_output = return_layernorm_output - self.zero_centered_gamma = zero_centered_gamma - self.backend = backend - - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.set_parallel_mode = set_parallel_mode - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - if self.set_parallel_mode: - self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size) - else: - self.size_per_partition = self.ffn_hidden_size - - # LayerNorm weights - self.ln_weight = self.create_parameter( - shape=[self.hidden_size], - attr=paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ), - dtype=self._dtype, - is_bias=False, - ) - - if self.normalization != "RMSNorm": - self.ln_bias = self.create_parameter( - shape=[self.hidden_size], - attr=paddle.ParamAttr(initializer=Constant(value=0.0)), - 
dtype=self._dtype, - is_bias=True, - ) - else: - self.ln_bias = None - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.ln_weight) - if self.ln_bias is not None: - mark_as_sequence_parallel_parameter(self.ln_bias) - - # FC1 weights - if self.activation in ["swiglu"]: - fc1_output_features = self.size_per_partition * 2 - else: - fc1_output_features = self.size_per_partition - - with track_rng_state(enable=self.tensor_parallel): - self.fc1_weight = self.create_parameter( - shape=( - [fc1_output_features, self.hidden_size] - if self.backend == "transformer_engine" - else [self.hidden_size, fc1_output_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend - ) - self.fp8_weights.append(self.fc1_weight) - - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if use_default_bias: - self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0)) - - if self.has_bias: - self.fc1_bias = self.create_parameter( - shape=[fc1_output_features], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - set_tensor_dist_attr(self.fc1_bias, self.tensor_parallel, axis=0) - else: - self.fc1_bias = None - - # FC2 weights - self.fc2_weight = self.create_parameter( - shape=( - [self.hidden_size, self.size_per_partition] - if self.backend == "transformer_engine" - else [self.size_per_partition, self.hidden_size] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend - ) - self.fp8_weights.append(self.fc2_weight) - - if self.has_bias: - self.fc2_bias = self.create_parameter( - shape=[self.hidden_size], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - if self.set_parallel_mode and 
self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.fc2_bias) - else: - self.fc2_bias = None - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.set_parallel_mode and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - """ - - with self.prepare_forward(inp, num_gemms=2, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. 
- fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = ( - self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - ) - - out = _LayerNormMLP.apply( - inp, - self.ln_weight, - self.ln_bias, - self.fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - self.fc1_bias, - self.has_bias, - self.fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - self.fc2_bias, - self.has_bias, - self.eps, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - self.return_layernorm_output, - paddle.is_grad_enabled(), - self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.activation, - self.set_parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - ) - - if self.return_layernorm_output: - out, ln_out = out - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.fc2_bias, self.activation_dtype) - - if self.return_layernorm_output: - return out, ln_out - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." 
- ) - - if self.normalization == "RMSNorm": - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps, - ) - if self.set_parallel_mode and self.tensor_parallel: - norm_out = identity(norm_out, self.tp_group) - fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias) - act_func = get_paddle_act_func(self.activation) - act_out = act_func(fc1_out) - out = F.linear( - act_out, self.fc2_weight, self.fc2_bias if self.gemm_bias_fused_add else None - ) - if self.set_parallel_mode and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.fc2_bias if self.fc2_bias is not None else out - if self.return_layernorm_output: - return out, norm_out - return out - - def forward(self, *args, **kwargs): - """ - Apply layer normalization to the input followed by a feedforward network (MLP Block). - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. 
- When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py deleted file mode 100644 index 78b22ac7e4..0000000000 --- a/transformer_engine/paddle/layer/linear.py +++ /dev/null @@ -1,919 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Linear API""" - -import warnings -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from .base import ( - TransformerEngineBaseLayer, - get_workspace, - _2X_ACC_FPROP, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, -) - -from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type -from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose -from ..distributed import ( - allgather, - allreduce, - get_tp_group_and_world_size, - identity, - reduce_scatter, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype, get_global_fp8_state -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - get_bias_dtype, - save_for_backward_allow_none, - saved_tensor_allow_none, - clear_tensor_data, -) - -__all__ = ["Linear"] - - -def _linear_fwd_fp8( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - weight_fp8_index: FP8FwdTensors, - bias: 
paddle.Tensor, - use_bias: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_grad_enabled: bool, - is_first_microbatch: bool = None, -): - """FP8 path of Linear Fwd""" - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - bias_dtype = get_bias_dtype(activation_dtype) - bias = cast_if_needed(bias, bias_dtype) - - if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = allgather(inputmat, tp_group) - else: - inputmat_total = inputmat - - if not get_global_fp8_state().is_cudagraph_enabled(): - # if cuda graph is not enabled, we cast the weight here - update_fp8_weights = is_first_microbatch is None or is_first_microbatch - if is_grad_enabled: - if update_fp8_weights: - weight_fp8, weight_t_fp8 = cast_transpose( - weight, - fp8_meta["scaling_fwd"], - weight_fp8_index, - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) - else: - weight_t_fp8 = None - if update_fp8_weights: - weight_fp8 = cast_to_fp8( - weight, - fp8_meta["scaling_fwd"], - weight_fp8_index, - fp8_dtype_forward, - out=weight_fp8, - ) - - out, _ = fp8_gemm( - weight_fp8, - fp8_meta["scaling_fwd"].scale_inv, - weight_fp8_index, - fp8_dtype_forward, - inputmat_total, - fp8_meta["scaling_fwd"].scale_inv, - inputmat_fp8_index, - fp8_dtype_forward, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ) - - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - - return out, weight_t_fp8 - - -def _linear_fwd_non_fp8( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_calibration: bool, - 
fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - activation: str = "", -): - """Non-FP8 path of Linear Fwd""" - - if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = allgather(inputmat, tp_group) - else: - inputmat_total = inputmat - - # Layer parameters are initialized as float32 dtype by default. - # Cast the parameters to activation_dtype if the current dtype - # does not match activation_dtype. The casting is inplace, so it - # only needs to performed once throughout the traing process. - weight = cast_if_needed_inplace(weight, activation_dtype) - bias = cast_if_needed_inplace(bias, activation_dtype) - - if fp8_calibration: - # amax of input - fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max( - paddle.abs(inputmat_total) - ).item() - # amax of weight - fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max( - paddle.abs(weight) - ).item() - fp8_meta["update_amax_and_scale_fwd"] = True - - outputs = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - gelu=(activation == "gelu"), - ) - - if activation == "gelu": - gelu_out, _, out = outputs - return out, gelu_out - - out, _, _ = outputs - - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - return out - - -def _linear_fwd( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - 
tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_grad_enabled: bool, - is_first_microbatch: bool = None, - gather_output: bool = False, -): - if fp8_enabled: - out, weight_t_fp8 = _linear_fwd_fp8( - inputmat, - inputmat_fp8_index, - weight, - weight_fp8, - weight_t_fp8, - weight_fp8_index, - bias, - use_bias, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - else: - out = _linear_fwd_non_fp8( - inputmat, - inputmat_fp8_index, - weight, - weight_fp8_index, - bias, - use_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - ) - if gather_output and tensor_parallel and parallel_mode == "column": - out, _ = allgather(out, tp_group, axis=-1) - - return ( - out, - weight_t_fp8 if fp8_enabled else None, - ) - - -def _linear_bwd_fp8( - inputmat: paddle.Tensor, - inputmat_t: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_t_fp8: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, - fwd_scale_inverses: paddle.Tensor, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - dgrad, wgrad, handle = None, None, None - - # Overlap input AG with dgrad - inputmat_total = None - inputmat_t_total = None - if requires_wgrad and parallel_mode == "column" and sequence_parallel: - inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad) - else: - inputmat_total = inputmat - inputmat_t_total = inputmat_t - - 
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - if requires_dgrad: - dgrad, _ = fp8_gemm( - weight_t_fp8, - fwd_scale_inverses, - weight_fp8_index, - fp8_dtype_forward, - grad_output_c, - fp8_meta["scaling_bwd"].scale_inv, - grad_output_fp8_index, - fp8_dtype_backward, - activation_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ) - clear_tensor_data(grad_output_c) - - # Overlap dgrad-RS/AR with wgrad - if parallel_mode == "column" and sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False) - elif parallel_mode == "column" and tensor_parallel: - dgrad, handle = allreduce(dgrad, tp_group, sync_op=False) - - if requires_wgrad: - if not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t_total is None: - inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward) - clear_tensor_data(inputmat_total) - - wgrad, _ = fp8_gemm( - inputmat_t_total, - fwd_scale_inverses, - inputmat_fp8_index, - fp8_dtype_forward, - grad_output_t, - fp8_meta["scaling_bwd"].scale_inv, - grad_output_fp8_index, - fp8_dtype_backward, - "float32" if fuse_wgrad_accumulation else activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ) - clear_tensor_data(inputmat_t_total, grad_output_t) - else: - wgrad, _, _ = gemm( - inputmat_total, - grad_output, - activation_dtype, - get_workspace(), - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - layout="NT", - out=weight.main_grad if fuse_wgrad_accumulation else None, - out_dtype="float32" if fuse_wgrad_accumulation else None, - ) - clear_tensor_data(inputmat_total) - - if fuse_wgrad_accumulation: - weight.main_grad = wgrad - - if parallel_mode == "column" and tensor_parallel and handle is not 
None: - handle.wait() - if parallel_mode == "column" and sequence_parallel: - handle.wait() - - return dgrad, wgrad - - -def _linear_bwd_non_fp8( - inputmat: paddle.Tensor, - weight: paddle.Tensor, - grad_output: paddle.Tensor, - requires_bgrad: bool, - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, - gelu_input: Union[paddle.Tensor, None] = None, - activation: str = "", -): - """ - Performs Linear Backward. Optionally, fuses GELU backward and dbias. - """ - dgrad, wgrad, bgrad, handle = None, None, None, None - - # Overlap input AG with dgrad - inputmat_total = None - if requires_wgrad and parallel_mode == "column" and sequence_parallel: - inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad) - else: - inputmat_total = inputmat - - if requires_dgrad: - dgrad, _, _ = gemm( - weight, - grad_output, - activation_dtype, - get_workspace(), - layout="NN", - gelu=(activation == "gelu"), - gelu_input=gelu_input, - grad=True, - ) - # Overlap dgrad-RS/AR with wgrad - if parallel_mode == "column" and sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False) - elif parallel_mode == "column" and tensor_parallel: - dgrad, handle = allreduce(dgrad, tp_group, sync_op=False) - - if requires_wgrad: - wgrad, bgrad, _ = gemm( - inputmat_total, - grad_output, - activation_dtype, - get_workspace(), - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - layout="NT", - out=weight.main_grad if fuse_wgrad_accumulation else None, - out_dtype="float32" if fuse_wgrad_accumulation else None, - use_bias=requires_bgrad, - ) - if fuse_wgrad_accumulation: - weight.main_grad = wgrad - - elif requires_bgrad: - bgrad = grad_output.sum(axis=0) - if 
parallel_mode == "column" and tensor_parallel and handle is not None: - handle.wait() - if parallel_mode == "column" and sequence_parallel and handle is not None: - handle.wait() - - return dgrad, wgrad, bgrad - - -def _linear_bwd( - inputmat: paddle.Tensor, - inputmat_t: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_t_fp8: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, - fwd_scale_inverses: paddle.Tensor, - requires_bgrad: bool, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - dgrad, wgrad, bgrad = None, None, None - if fp8_enabled: - dgrad, wgrad = _linear_bwd_fp8( - inputmat, - inputmat_t, - inputmat_fp8_index, - weight, - weight_t_fp8, - weight_fp8_index, - grad_output, - grad_output_c, - grad_output_t, - grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - requires_dgrad, - requires_wgrad, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - else: - dgrad, wgrad, bgrad = _linear_bwd_non_fp8( - inputmat, - weight, - grad_output, - requires_bgrad, - requires_dgrad, - requires_wgrad, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - return dgrad, wgrad, bgrad - - -class _Linear(paddle.autograd.PyLayer): - """TE implementation of Linear""" - - 
@staticmethod - def forward( - ctx, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - inp: paddle.Tensor, - bias: paddle.Tensor, - use_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - is_grad_enabled: bool, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - gather_output: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = weight.shape[-1] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(weight) - - inputmat_no_fp8 = inputmat - - # FP8 casting - inputmat_t = None - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and not sequence_parallel - ): - inputmat, inputmat_t = cast_transpose( - inputmat, - fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - else: - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - - # GEMM Fwd - out, weight_t_fp8 = _linear_fwd( - inputmat, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_fp8, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - bias, - use_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - gather_output, - ) - - if is_grad_enabled: - saved_inputmat = None - if fp8_enabled and sequence_parallel: - saved_inputmat = inputmat - else: - saved_inputmat = inputmat_no_fp8 - save_for_backward_allow_none( - ctx, - 
saved_inputmat, - inputmat_t, - weight, - weight_t_fp8 if fp8_enabled else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - ctx.activation_dtype = activation_dtype - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_wgrad = not weight.stop_gradient - ctx.requires_bgrad = use_bias and not bias.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.reduce_scatter_output = gather_output - - return out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" - ): - - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - inputmat_t, - weight, - weight_t_fp8, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ( - grad_output, - grad_output_c, - grad_output_t, - bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess( - ctx, grad_output, ctx.parallel_mode == "row" - ) - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - dgrad, wgrad, bgrad_ = _linear_bwd( - inputmat, - inputmat_t, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.requires_bgrad, - ctx.fp8_enabled, - ctx.fp8_meta, - ctx.requires_dgrad, - ctx.requires_wgrad, - 
ctx.activation_dtype, - ctx.parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - if not ctx.fp8_enabled: - # bgrad is fused with gemm for non-FP8 path - bgrad = bgrad_ - - if ctx.reduce_scatter_output: - wgrad, _ = reduce_scatter(wgrad, ctx.tp_group) - bgrad, _ = reduce_scatter(bgrad, ctx.tp_group) - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - weight_cache_grad = (None, None) - - dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None - if not ctx.use_bias: - bgrad_return = () - elif ctx.requires_bgrad: - bgrad_return = (bgrad,) - else: - bgrad_return = (None,) - - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - wgrad = None - - return ( - wgrad if ctx.requires_wgrad else None, - *weight_cache_grad, - dgrad_return, - *bgrad_return, - ) - - -class Linear(TransformerEngineBaseLayer): - """ - Applies a linear transformation to the incoming data :math:`y = xA^T + b` - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. 
- sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - - """ - - def __init__( - self, - in_features: int, - out_features: int, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - parallel_mode: Optional[str] = None, - sequence_parallel: bool = False, - tp_group: Union[dist_group_type, None] = None, - fuse_wgrad_accumulation: bool = False, - gather_output: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.backend = backend - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - self.gather_output = gather_output - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=parallel_mode is not None - ) - self.tensor_parallel = self.tp_size > 1 - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - # Initialize weight parameter - with track_rng_state(enable=self.tensor_parallel): - # TE linear weight is in column major - self.weight = 
self.create_parameter( - shape=( - [self.out_features, self.in_features] - if self.backend == "transformer_engine" - else [self.in_features, self.out_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.weight, self.tensor_parallel, self.parallel_mode, self.backend - ) - - # Initialize bias parameter - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if self.has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=( - self._bias_attr - if not use_default_bias - else paddle.ParamAttr(initializer=Constant(value=0.0)) - ), - dtype=self._dtype, - is_bias=True, - ) - if parallel_mode == "column": - set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) - if parallel_mode == "row" and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.bias) - else: - self.bias = None - - self.fp8_weights.append(self.weight) - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """ - Apply the linear transformation to the input. - """ - with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. 
- weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - - out = _Linear.apply( - self.weight, - weight_fp8, - weight_t_fp8, - inp, - self.bias if self.gemm_bias_fused_add else None, - self.has_bias and self.gemm_bias_fused_add, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - paddle.is_grad_enabled(), - self.parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - self.gather_output, - ) - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) - - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - if self.parallel_mode == "column" and self.tensor_parallel: - inp = identity(inp, self.tp_group) - out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == "row" and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.bias if self.bias is not None else out - return out - - def forward(self, *args, **kwargs): - """ - Apply the linear transformation to the input. - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. 
- When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/rmsnorm.py b/transformer_engine/paddle/layer/rmsnorm.py deleted file mode 100644 index 23e406e3fb..0000000000 --- a/transformer_engine/paddle/layer/rmsnorm.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""RMSNorm API""" -import os -from typing import Union, Tuple - -import paddle -from paddle.nn.initializer import Constant - -from ..constants import TE_DType -from ..cpp_extensions import rmsnorm_fwd, rmsnorm_bwd -from ..distributed import mark_as_sequence_parallel_parameter - -__all__ = ["RMSNorm"] - - -class _RMSNorm(paddle.autograd.PyLayer): - """functional RMSNorm""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - rmsnorm_weight: paddle.Tensor, - eps: float, - fwd_rmsnorm_sm_margin: int, - bwd_rmsnorm_sm_margin: int, - zero_centered_gamma: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = rmsnorm_weight.shape[0] - assert inp.shape[-1] == in_features, "RMSNorm not possible" - inputmat = inp.reshape((-1, in_features)) - - rmsnorm_out, rsigma = rmsnorm_fwd( - inputmat, - rmsnorm_weight, - eps, - TE_DType[inp.dtype], - fwd_rmsnorm_sm_margin, - zero_centered_gamma, - ) - - ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.requires_dx = not inp.stop_gradient - ctx.requires_dw = not rmsnorm_weight.stop_gradient - - return 
rmsnorm_out.reshape(inp.shape) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor() - d_rmsnorm_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma = rmsnorm_bwd( - d_rmsnorm_out, - inputmat, - rsigma, - rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, - ctx.zero_centered_gamma, - ) - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, - dgamma if ctx.requires_dw else None, - ) - - -class RMSNorm(paddle.nn.Layer): - r""" - Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in - the paper `Root Mean Square Layer Normalization `__ - - .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * \gamma - - where - - .. math:: - RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon} - - :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` - - Parameters - ---------- - hidden_size : int - size of each input sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in RMSNorm is initialized to 0 and - the RMSNorm formula changes to - - .. math:: - y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma) - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - backend to use for rmsnorm operation. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. 
- """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - zero_centered_gamma: bool = False, - sequence_parallel: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.sequence_parallel = sequence_parallel - self.backend = backend - self._dtype = self._helper.get_default_dtype() - - self._weight_attr = weight_attr - if not self._weight_attr: - self._weight_attr = paddle.ParamAttr(initializer=Constant(1.0)) - - self.weight = self.create_parameter( - shape=[hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - - # These many SMs are subtracted from the total SM count when calling forward - # and backward RMSNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with RMSNorm. - self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: - return _RMSNorm.apply( - inp, - self.weight, - self.eps, - self.fwd_rmsnorm_sm_margin, - self.bwd_rmsnorm_sm_margin, - self.zero_centered_gamma, - ) - - def _pd_forward( - self, - inp: paddle.Tensor, - ) -> paddle.Tensor: - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support RMSNorm with zero_centered_gamma." 
- ) - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - y = inp * norm * self.weight - return y - - def forward(self, *args, **kwargs): - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} not supported.") diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py deleted file mode 100644 index 971be68167..0000000000 --- a/transformer_engine/paddle/layer/softmax.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Fused scaled masked softmax functions""" - -import os -import warnings -from typing import Callable, Tuple, Union, Optional - -import paddle - -from transformer_engine.paddle.cpp_extensions import ( - scaled_upper_triang_masked_softmax_forward, - scaled_upper_triang_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_backward, - scaled_softmax_forward, - scaled_softmax_backward, -) - - -__all__ = ["FusedScaleMaskSoftmax"] - - -THREADS_PER_WARP = 32 -THREADS_PER_BLOCK = 128 - - -_default_causal_mask = {} - - -def _get_default_causal_mask(seqlen: int) -> paddle.Tensor: - """Return the causal upper triangular mask for softmax input""" - if seqlen not in _default_causal_mask: - _default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), diagonal=1).cast( - "bool" - ) - return _default_causal_mask[seqlen] - - -class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledUpperTriangMaskedSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledUpperTriangMaskedSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -class ScaledMaskedSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, mask: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledMaskedSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledMaskedSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class ScaledSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - - softmax_results = scaled_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - - input_grads = scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(paddle.nn.Layer): - """ - Scaled and masked softmax module for paddle with fused optimizations. - - Parameters - ---------- - attn_mask_type : str, default = `causal` - type of attention mask, can be 'causal', 'padding', or 'no_mask'. - mask_func : callable - custom callable for applying the mask to the softmax input. - `masked_input=mask_func(inp, mask)`. - softmax_in_fp32 : bool, default = True - perform softmax computation in fp32. - layernorm_epsilon : float, default = 1e-5 - a value added to the denominator of layer normalization - for numerical stability. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for operation. 
- """ - - def __init__( - self, - attn_mask_type: str, - mask_func: Callable, - softmax_in_fp32: bool = True, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))) - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.backend = backend - - def forward( - self, - inp: paddle.Tensor, - mask: paddle.Tensor, - scale: Optional[float] = None, - ) -> paddle.Tensor: - """FusedScaleMaskSoftmax fprop""" - # [batch_size, num_heads, s_q, s_kv] - assert inp.dim() == 4 - self.input_is_fp16 = inp.dtype == paddle.float16 - self.input_is_bf16 = inp.dtype == paddle.bfloat16 - self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16 - - assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - - if self.backend == "transformer_engine" and not self.is_kernel_available(*inp.shape): - warnings.warn( - "fused kernel is not available for this input shape, fall back to paddle backend" - ) - self.backend = "paddle" - - if self.backend == "transformer_engine": - return self._te_forward(inp, mask, scale) - if self.backend == "paddle": - return self._pd_forward(inp, mask, scale) - raise AttributeError(f"Backend {self.backend} is not supported.") - - def is_kernel_available(self, b: int, h: int, s_q: int, s_kv: int) -> bool: - """Check FusedScaleMaskSoftmax kernel availability based on size""" - attn_batches = b * h - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_16bit_float # input must be fp16 - and 16 < s_kv <= 4096 # s_kv must be 16 ~ 2048 - and s_q % 4 == 0 # s_q must be a multiple of 4 - and attn_batches % 4 == 0 # b * h must be a multiple of 4 - ): - if 0 <= s_kv <= 4096: - batch_per_block = self.get_batch_per_block(int(s_kv)) - - if self.attn_mask_type == "causal": - if attn_batches % batch_per_block == 0: - return True - else: - if s_q 
% batch_per_block == 0: - return True - return False - - def _te_forward( - self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None - ) -> paddle.Tensor: - """Fused masked softmax kernel""" - b, h, s_q, s_kv = inp.size() - scale = 1.0 if scale is None else scale - - if self.attn_mask_type == "causal": - assert s_q == s_kv, "causal mask is only for self attention" - - # input is 3D tensor (attn_batches, s_q, s_kv) - inp = inp.reshape((-1, s_q, s_kv)) - probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale) - return probs.reshape((b, h, s_q, s_kv)) - # input is 4D tensor (b, h, s_q, s_kv) - if mask is not None: - return ScaledMaskedSoftmax.apply(inp, mask, scale) - return ScaledSoftmax.apply(inp, scale) - - def _pd_forward( - self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None - ) -> paddle.Tensor: - """Call Paddle OP""" - if self.input_in_16bit_float and self.softmax_in_fp32: - inp = paddle.cast(inp, "float32") - - if scale is not None: - inp = inp * scale - - if self.attn_mask_type == "causal": - mask = _get_default_causal_mask(inp.shape[2]) - - mask_output = self.mask_func(inp, mask) if mask is not None else inp - probs = paddle.nn.functional.softmax(mask_output, axis=-1) - - if self.input_in_16bit_float and self.softmax_in_fp32: - if self.input_is_fp16: - probs = paddle.cast(probs, "float16") - else: - probs = paddle.cast(probs, "bfloat16") - - return probs - - @staticmethod - def get_batch_per_block(key_seq_len: int) -> int: - """Softmax utility""" - pow2 = 1 << (key_seq_len - 1).bit_length() - warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP - batches_per_warp = 2 if pow2 <= 128 else 1 - warps_per_block = THREADS_PER_BLOCK // warp_size - batches_per_block = warps_per_block * batches_per_warp - return batches_per_block diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py deleted file mode 100644 index feb79c0caa..0000000000 --- 
a/transformer_engine/paddle/layer/transformer.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Transformer""" - -from typing import Optional, Tuple, Union -import warnings - -import paddle -from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd - -from .layernorm_mlp import LayerNormMLP -from .layernorm import LayerNorm -from .attention import MultiHeadAttention -from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -from ..distributed import get_tp_group_and_world_size, track_rng_state - - -class TransformerLayer(paddle.nn.Layer): - r""" - TransformerLayer is made up of an attention block and a feedforward network (MLP). - This standard layer is based on the paper "Attention Is All You Need". - - Parameters - ---------- - hidden_size : int - size of each input sample. - ffn_hidden_size : int - intermediate size to which input samples are projected. - num_attention_heads : int - number of attention heads in the transformer layer. - num_gqa_groups : Optional[int], default = `None` - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the queries. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - layernorm_epsilon : float, default = 1e-5 - a value added to the denominator of layer normalization - for numerical stability. - hidden_dropout: float, default = 0.1 - dropout probability for the dropout op after FC2 layer. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. 
- self_attn_mask_type: {'causal', 'padding'}, default = `causal` - type of attention mask passed into softmax operation. - apply_residual_connection_post_layernorm : bool, default = `False` - if set to `True`, residual connections are taken - from the output of layer norm (default is taken - from input of layer norm) - output_layernorm: bool, default = `False` - if set to `True`, layer normalization is applied on the output side, - after the final dropout-add. default behavior is to apply layer - normalization on the input side, before the QKV transformation. - layer_type: {'encoder', 'decoder'}, default = `encoder` - if set to `decoder`, an additional cross-attn block is added after self-attn. - This can be used for structures like `T5` Transformer in conjunction with the - `encoder` option. - normalization: {'LayerNorm', 'RMSNorm'}, default = `LayerNorm` - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - activation : str, default = 'gelu' - Type of activation used in MLP block. - Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'. - - params_dtype : paddle.dtype, default = `paddle.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel - whereas PROJ and FC2 is used as Row Parallel as described - `here `_. 
- sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - attention_dropout_rng_state_name : str, default = `local_seed` - Controls the rng state used for dropout on attention probs. The - specified rng should be set different seeds for different TP ranks. - It will be ignored if `set_parallel_mode` is False. - hidden_dropout_rng_state_name : str, default = `global_seed` - Controls the rng state used for dropout on hidden states. The - specified rng should be given the same seeds for different TP - ranks. It will be ignored if `set_parallel_mode` is False. The - specified name should be registered through - `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() - .add(rng_state_name, seed)`. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. 
- - """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - num_attention_heads: int, - num_gqa_groups: Optional[int] = None, - layernorm_epsilon: float = 1e-5, - hidden_dropout: float = 0.1, - attention_dropout: float = 0.1, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - self_attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - apply_residual_connection_post_layernorm: bool = False, - output_layernorm: bool = False, - layer_type: str = "encoder", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - activation: str = "gelu", - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - attention_dropout_rng_state_name: str = "local_seed", - hidden_dropout_rng_state_name: str = "global_seed", - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype - self.output_layernorm = output_layernorm - self.layer_type = layer_type - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.self_attn_mask_type = self_attn_mask_type - self.set_parallel_mode = set_parallel_mode - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.sequence_parallel = self.tensor_parallel and sequence_parallel - self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name - # SP needs local seed for hidden dropout - if self.sequence_parallel and self.hidden_dropout_rng_state_name == "global_seed": - warnings.warn( - "RNG state for hidden dropout needs to be different across TP ranks. 
" - "Forcing hidden_dropout_rng_state_name to 'local_seed'" - ) - self.hidden_dropout_rng_state_name = "local_seed" - - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" - assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" - - attention_args = ( - hidden_size, - num_attention_heads, - attention_dropout, - layernorm_epsilon, - weight_attr, - bias_attr, - ) - common_attention_kwargs = { - "params_dtype": params_dtype, - "return_layernorm_output": apply_residual_connection_post_layernorm, - "normalization": normalization, - "zero_centered_gamma": zero_centered_gamma, - "set_parallel_mode": set_parallel_mode, - "sequence_parallel": self.sequence_parallel, - "max_sequence_length": max_sequence_length, - "tp_group": tp_group, - "num_gqa_groups": num_gqa_groups, - "fuse_wgrad_accumulation": fuse_wgrad_accumulation, - "rng_state_name": attention_dropout_rng_state_name, - "backend": backend, - } - - self.self_attention = MultiHeadAttention( - *attention_args, - **common_attention_kwargs, - attn_mask_type=self_attn_mask_type, - input_layernorm=not output_layernorm, - attention_type="self", - ) - - if layer_type == "decoder": - self.inter_attention = MultiHeadAttention( - *attention_args, - **common_attention_kwargs, - attn_mask_type="padding", - input_layernorm=True, - attention_type="cross", - ) - - self.layernorm_mlp = LayerNormMLP( - hidden_size, - ffn_hidden_size, - eps=layernorm_epsilon, - weight_attr=weight_attr, - bias_attr=bias_attr, - normalization=normalization, - activation=activation, - return_layernorm_output=apply_residual_connection_post_layernorm, - zero_centered_gamma=zero_centered_gamma, - set_parallel_mode=set_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=backend, - ) - - self.hidden_dropout = hidden_dropout - - if self.output_layernorm: - self.layernorm = LayerNorm( - 
hidden_size, - layernorm_epsilon, - weight_attr, - bias_attr, - zero_centered_gamma=zero_centered_gamma, - sequence_parallel=self.sequence_parallel, - backend=backend, - ) - - self.fused_dropout_add1 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - if self.layer_type == "decoder": - self.fused_dropout_add2 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - self.fused_dropout_add3 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - - def forward( - self, - hidden_states: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - encoder_output: Optional[paddle.Tensor] = None, - enc_dec_attn_mask: Optional[paddle.Tensor] = None, - rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - recompute_core_attention: bool = False, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """ - Transformer Layer: attention block and a feedforward network (MLP) - - .. note:: - - Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` - is set to `"causal"`. - - Parameters - ---------- - hidden_states : paddle.Tensor - Input tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out self-attention softmax input. - encoder_output : Optional[paddle.Tensor], default = `None` - Output of the encoder block to be fed into the decoder block if using - `layer_type="decoder"`. - enc_dec_attn_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out inter-attention softmax input if using - `layer_type="decoder"`. - rotary_pos_emb : Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. 
By default no input embedding is applied - core_attention_bias_type: str, default = `no_bias` - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to set output tensors to 0 or not before use. - recompute_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - - if self.self_attn_mask_type != "causal" and attention_mask is not None: - assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - - assert core_attention_bias_type in ["no_bias"], ( - "Only no_bias is supported currently, " - f"but receive core_attention_bias_type = {core_attention_bias_type}" - ) - - # Self attention. - self_attention_outputs = self.self_attention( - hidden_states, - attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - rotary_pos_emb=rotary_pos_emb, - recompute_core_attention=recompute_core_attention, - is_first_microbatch=is_first_microbatch, - ) - - if self.apply_residual_connection_post_layernorm and not self.output_layernorm: - attention_output, residual = self_attention_outputs - else: - attention_output = self_attention_outputs - residual = hidden_states - - # dropoout add. 
- with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name): - bda_output = self.fused_dropout_add1(attention_output, residual) - - # Cross attention. - if self.layer_type == "decoder": - inter_attention_outputs = self.inter_attention( - bda_output, - enc_dec_attn_mask, - encoder_output=encoder_output, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - recompute_core_attention=recompute_core_attention, - is_first_microbatch=is_first_microbatch, - ) - if self.apply_residual_connection_post_layernorm: - attention_output, residual = inter_attention_outputs - else: - attention_output = inter_attention_outputs - residual = bda_output - - with track_rng_state( - enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name - ): - bda_output = self.fused_dropout_add2(attention_output, residual) - - # MLP. - mlp_outputs = self.layernorm_mlp(bda_output, is_first_microbatch=is_first_microbatch) - if self.apply_residual_connection_post_layernorm: - mlp_output, residual = mlp_outputs - else: - mlp_output = mlp_outputs - residual = bda_output - - # dropoout add. - with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name): - output = self.fused_dropout_add3(mlp_output, residual) - - # For BERT like architectures. - if self.output_layernorm: - output = self.layernorm(output) - - # output: [b, s, hidden] - return output diff --git a/transformer_engine/paddle/profile.py b/transformer_engine/paddle/profile.py deleted file mode 100644 index d58679aea1..0000000000 --- a/transformer_engine/paddle/profile.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Utils for profiling""" - -from contextlib import contextmanager - -try: - from paddle.base import core -except ImportError: - from paddle.fluid import core - - -@contextmanager -def nvtx_range(msg): - """Context to insert NVTX""" - core.nvprof_nvtx_push(msg) - yield - core.nvprof_nvtx_pop() diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py deleted file mode 100644 index 5551583736..0000000000 --- a/transformer_engine/paddle/recompute.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Methods needed for recompute.""" - -import os -import inspect - -from paddle.distributed import fleet - -from .constants import RecomputeFunctionNames -from .fp8 import get_global_fp8_state - - -__all__ = ["recompute"] - - -_DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) - - -def is_in_recompute_phase(): - """Inspect call stack to determine if this is called from - backward phase. Paddle has two recompute methods: - (1) Use RecomputeFunction. The recomputed function is called from `RecomputeFunction.backward`; - (2) Use paddle.autograd.saved_tensors_hooks. The recompute function is called from `unpack`.""" - if _DISABLE_RECOMPUTE: - return False - frame = inspect.currentframe().f_back - while frame: - if frame.f_code.co_name in RecomputeFunctionNames: - return True - frame = frame.f_back - return False - - -def recompute(function, *args, **kwargs): - """ - This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary - state information for fp8 layers. - - Parameters - ---------- - function: Callable - paddle module used to run the forward and backward passes using - the specified :attr:`args` and :attr:`kwargs`. - args : tuple - tuple of torch tensors for inputs to :attr:`function`. - kwargs : dict - dictionary of string keys for keyword arguments to :attr:`function`. 
- """ - assert ( - not _DISABLE_RECOMPUTE - ), f"Recompute is disabled. Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}." - - global_fp8_state = get_global_fp8_state() - - try: - global_fp8_state._fp8_recompute_enabled = True - outputs = fleet.utils.recompute(function, *args, **kwargs) - finally: - global_fp8_state._fp8_recompute_enabled = False - - return outputs diff --git a/transformer_engine/paddle/setup.py b/transformer_engine/paddle/setup.py deleted file mode 100644 index c80f21a01d..0000000000 --- a/transformer_engine/paddle/setup.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Installation script for TE paddle-paddle extensions.""" - -# pylint: disable=wrong-import-position,wrong-import-order - -import sys -import os -import shutil -from pathlib import Path - -import setuptools -from paddle.utils.cpp_extension import BuildExtension - -try: - import paddle # pylint: disable=unused-import -except ImportError as e: - raise RuntimeError("This package needs Paddle Paddle to build.") from e - - -current_file_path = Path(__file__).parent.resolve() -build_tools_dir = current_file_path.parent.parent / "build_tools" -if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir): - build_tools_copy = current_file_path / "build_tools" - if build_tools_copy.exists(): - shutil.rmtree(build_tools_copy) - shutil.copytree(build_tools_dir, build_tools_copy) - - -from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers -from build_tools.te_version import te_version -from build_tools.paddle import setup_paddle_extension - - -os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) - - -if __name__ == "__main__": - # Extensions - common_headers_dir = "common_headers" - copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) - 
ext_modules = [ - setup_paddle_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir - ) - ] - - # Configure package - setuptools.setup( - name="transformer_engine_paddle", - version=te_version(), - description="Transformer acceleration library - Paddle Paddle Lib", - ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["paddlepaddle-gpu>=2.6.1"], - tests_require=["numpy"], - ) - if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): - shutil.rmtree(common_headers_dir) - shutil.rmtree("build_tools") diff --git a/transformer_engine/paddle/utils.py b/transformer_engine/paddle/utils.py deleted file mode 100644 index 4a801495ab..0000000000 --- a/transformer_engine/paddle/utils.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Utility functions for Transformer Engine modules""" - -from typing import Optional, Tuple, Union - -import paddle -import paddle.nn.functional as F -from .cpp_extensions import swiglu_pd - - -def cast_if_needed( - tensor: Union[paddle.Tensor, None], dtype: paddle.dtype -) -> Union[paddle.Tensor, None]: - """Cast tensor to dtype""" - return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype) - - -def cast_if_needed_inplace( - tensor: Union[paddle.Tensor, None], dtype: paddle.dtype -) -> Union[paddle.Tensor, None]: - """Cast tensor to dtype (inplace), not to be used on layer inputs""" - return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype) - - -def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool: - """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. 
- """ - return not tensor.shape[0] % 8 and not tensor.shape[1] % 16 - - -def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None: - """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. - """ - # single tensor check so it's clear which tensor is triggering the assertion - assert check_dim_for_fp8_forward_exec(tensor), ( - "Tensor dimensions are not compatible for FP8 execution: " - f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)" - ) - - -def get_bias_dtype(activation_dtype: paddle.dtype): - """Get bias dtype given activation_dtype""" - return paddle.bfloat16 if activation_dtype == paddle.float32 else activation_dtype - - -def get_paddle_act_func(activation): - """Get paddle activation function""" - funcs = { - "gelu": F.gelu, - "relu": F.relu, - "silu": F.silu, - "swiglu": swiglu_pd, - } - if activation not in funcs: - raise "Activation type " + activation + " is not supported." - return funcs[activation] - - -def attention_mask_func( - attention_scores: paddle.Tensor, attention_mask: paddle.Tensor -) -> paddle.Tensor: - """Get attention mask""" - - def _masked_fill(x, mask, value): - y = paddle.full(x.shape, value, x.dtype) - return paddle.where(mask, y, x) - - attention_scores = _masked_fill(attention_scores, attention_mask, -10000.0) - return attention_scores - - -def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor: - """Convert mask to cu_seqlens""" - assert "bool" in str(mask.dtype), "mask must be bool dtype" - assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]" - q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype="int32") - q_cu_seqlens = paddle.cumsum(q_actual_seqlens) - q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0) - if not need_kv: - return q_cu_seqlens, None - kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), 
axis=(-1, -2), dtype="int32") - kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens) - kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0) - return q_cu_seqlens, kv_cu_seqlens - - -def divide(numerator: int, denominator: int) -> int: - """Ensure that numerator is divisible by the denominator and return - the division value.""" - assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" - return numerator // denominator - - -def save_for_backward_allow_none(ctx, *args) -> None: - """Save tensors for backward. Args could be None""" - indices_mapping = [] - tensors_to_save = [] - for x in args: - if isinstance(x, paddle.Tensor): - indices_mapping.append(len(tensors_to_save)) - tensors_to_save.append(x) - elif x is None: - indices_mapping.append(-1) - else: - raise ValueError(f"Type {type(x)} is not allowed.") - - ctx._indices_mapping = indices_mapping - ctx.save_for_backward(*tensors_to_save) - - -def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]: - """Used with `save_for_backward_allow_none` in pair. Get saved tensors from ctx.""" - assert hasattr( - ctx, "_indices_mapping" - ), "`saved_tensor_allow_none` must be used with `save_for_backward_allow_none` in pair." 
- - indices_mapping = ctx._indices_mapping - outputs = [] - saved_tensors = ctx.saved_tensor() - - for index in indices_mapping: - if index < 0: - outputs.append(None) - else: - outputs.append(saved_tensors[index]) - - return tuple(outputs) - - -def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None: - """ - Free tensor buffer - """ - - def can_free(t): - return ( - t is not None - and isinstance(t, paddle.Tensor) - and t._is_initialized() - and t.inplace_version == 0 - ) - - for t in tensors: - if can_free(t): - t._clear_dataptr() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 91d3772fd7..57addca3b9 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -82,27 +82,12 @@ def _load_library(): from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables -from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers -# Register custom op symbolic ONNX functions -from transformer_engine.pytorch.te_onnx_extensions import ( - onnx_cast_to_fp8, - onnx_cast_to_fp8_noalloc, - onnx_cast_from_fp8, - onnx_fp8_gelu, - onnx_fp8_relu, - onnx_te_gemm, - onnx_layernorm_fwd_fp8, - onnx_layernorm_fwd, - onnx_rmsnorm_fwd, - onnx_rmsnorm_fwd_fp8, -) - try: torch._dynamo.config.error_on_nested_jit_trace = False except AttributeError: diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ccceacff85..bf6adc309c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ 
-24,15 +24,7 @@ import transformer_engine_torch as tex import transformer_engine as te from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, -) from transformer_engine.pytorch.cpp_extensions.fused_attn import ( - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, fused_attn_fwd, fused_attn_bwd, QKVLayout, @@ -54,6 +46,7 @@ get_fp8_torch_dtype, ) from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( @@ -82,9 +75,13 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + prepare_for_saving, + restore_from_saved, +) # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 @@ -116,7 +113,8 @@ def _get_supported_versions(version_min, version_max): _flash_attn_is_installed = False _flash_attn_version = PkgVersion("0") _flash_attn_version_required = PkgVersion("2.1.1") -_flash_attn_max_version = PkgVersion("2.6.3") +_flash_attn_version_required_blackwell = PkgVersion("2.7.3") +_flash_attn_max_version = PkgVersion("2.7.3") _flash_attn_2_plus = False _flash_attn_2_1_plus = False _flash_attn_2_3_plus = False @@ -124,6 +122,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_4_1_plus = False _flash_attn_2_5_7_plus = False _flash_attn_2_6_0_plus = False +_flash_attn_2_7_0_plus = False flash_attn_cuda_bwd = None 
flash_attn_func = None @@ -142,7 +141,13 @@ def _get_supported_versions(version_min, version_max): """ "pip install flash-attn".""", ) else: - if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): + if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version: + _flash_attn_is_installed = True + elif _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + _flash_attn_is_installed = True + + if _flash_attn_is_installed: from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd @@ -154,7 +159,6 @@ def _get_supported_versions(version_min, version_max): _flash_attn_varlen_backward as _flash_attn_varlen_bwd, ) - _flash_attn_is_installed = True _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") @@ -162,13 +166,18 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") + _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0") elif ( torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN ): fa_logger.warning( "Supported flash-attn versions are %s. 
Found flash-attn %s.", _get_supported_versions( - _flash_attn_version_required, + ( + _flash_attn_version_required + if get_device_compute_capability() < (10, 0) + else _flash_attn_version_required_blackwell + ), _flash_attn_max_version, ), _flash_attn_version, @@ -181,11 +190,13 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False _use_flash_attn_3 = False +# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved +# https://github.com/Dao-AILab/flash-attention/issues/1452 _flash_attn_3_installation_steps = """\ -(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" +(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" (2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (3) mkdir -p $python_path/flashattn_hopper -(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" +(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" try: _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) except PackageNotFoundError: @@ -317,7 +328,7 @@ def __eq__(self, other): if fname != "fp8_meta": if sf != of: return False - elif sf["recipe"] != of["recipe"]: + elif sf.get("recipe", None) != of.get("recipe", None): return False return True @@ -434,15 +445,6 @@ def get_attention_backend( if not use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") - # Filter: ONNX mode - if is_in_onnx_export_mode(): - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to ONNX mode") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention due 
to ONNX mode") - use_fused_attention = False - # Filter: Compute capability if device_compute_capability < (8, 0): if use_flash_attention and _flash_attn_is_installed: @@ -937,7 +939,7 @@ def get_attention_backend( and use_fused_attention and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] ): - if device_compute_capability == (9, 0): + if device_compute_capability >= (9, 0): logger.debug( "Disabling FlashAttention to give FusedAttention preference on Hopper+ " "for performance reasons" @@ -1390,8 +1392,9 @@ def pack_tensor( indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) if isinstance(tensor, Float8Tensor): tensor_data = torch.cat((tensor._data, padding_indice), dim=0) + gathered_data = torch.gather(tensor_data, 0, indices) - packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices)) + packed = Float8Tensor.make_like(tensor, data=gathered_data, shape=gathered_data.shape) else: tensor = torch.cat((tensor, padding_indice), dim=0) @@ -1444,7 +1447,8 @@ def unpack_tensor( ) if isinstance(tensor, Float8Tensor): unpacked.scatter_(0, indices, tensor._data) - unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :]) + unpacked_data = unpacked[0:-1, :, :] + unpacked = Float8Tensor.make_like(tensor, data=unpacked_data, shape=unpacked_data.shape) else: unpacked.scatter_(0, indices, tensor) unpacked = unpacked[0:-1, :, :] @@ -1746,6 +1750,49 @@ def flash_attn_a2a_communicate( return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): + """Get the list of quantizers used in attention from the quantizers list.""" + if not fp8: + num_of_nones = 8 if cp_specific_quantizers else 6 + return [None] * num_of_nones + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] + QKV_quantizer.internal = True + QKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + 
O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.interal = True + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.interal = True + dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] + dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_CP_quantizer.internal = True + O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] + O_CP_quantizer.set_usage(rowwise=True, columnwise=False) + + if cp_specific_quantizers: + return ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. 
Exchange KV between CP ranks @@ -1784,6 +1831,7 @@ def forward( cp_group, cp_global_ranks, cp_stream, + quantizers, ): # pylint: disable=missing-function-docstring if softmax_scale is None: @@ -1839,56 +1887,58 @@ def forward( cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - fused_attn_qkv_dtype = None fused_attn_backend = None - amax_per_step = None qkv_dtype = q.dtype # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + is_output_fp8 = False + if fp8: + is_output_fp8 = fp8_meta["recipe"].fp8_mha + + ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) = get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + if fp8: if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] + assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." 
is_input_fp8 = isinstance(q, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: + if not is_input_fp8: q_f16, k_f16, v_f16 = q, k, v if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + q = QKV_quantizer(q_f16) if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [k_f16, v_f16] - ] + k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer # partial result quantizer else: assert False, "FP8 is only supported with Fused Attention!" 
else: q_f16 = q if use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if fp8: + q = q._data + k = k._data + v = v._data + if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) + q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) @@ -1896,7 +1946,7 @@ def forward( q_f16 = q elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16 = q - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + q = QKV_quantizer(q_f16)._data assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 @@ -1953,12 +2003,17 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_3_plus: + if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = 0 if causal else -1 if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if _flash_attn_2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 # Flash Attn inputs q_inputs = [None, None] @@ -2007,17 +2062,7 @@ def forward( kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = cast_to_fp8( - p2p_comm_buffers[i], - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - ) - if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step - fp8_meta_kwargs["amax_s_offset"] = i - fp8_meta_kwargs["amax_o"] = amax_per_step - fp8_meta_kwargs["amax_o_offset"] 
= cp_size + i + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i]) if causal: if i == 0: if pad_between_seqs_q: @@ -2058,25 +2103,40 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, - fused_attn_backend, + q_part, + k_part, + v_part, + fake_dtype=qkv_dtype, + fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, @@ -2117,10 +2177,16 @@ def forward( causal=True, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] elif i <= rank: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2160,24 +2226,38 @@ def forward( if attn_bias 
is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv // 2, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2207,8 +2287,13 @@ def forward( max_seqlen_q, max_seqlen_kv // 2, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_forward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -2225,10 +2310,16 @@ def forward( causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + 
out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2277,24 +2368,38 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q // 2, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2324,8 +2429,13 @@ def forward( max_seqlen_q // 2, max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_forward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -2342,10 +2452,16 @@ def forward( causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not 
_flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2374,24 +2490,38 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2433,10 +2563,16 @@ def forward( causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + 
rng_states[i] = fa_outputs[3] if i > 0: # wait until fwd restuls correction of last step is done @@ -2454,13 +2590,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: - out_per_step[i - 1] = cast_from_fp8( - out_per_step[i - 1], - fp8_meta["scaling_fwd"], - META_O_CP, - fp8_dtype_forward, - TE_DType[torch.float32], - ) + out_per_step[i - 1] = out_per_step[i - 1].dequantize() if i == 1: out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) @@ -2562,70 +2692,48 @@ def forward( elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) - if fp8 and use_fused_attention: - amax_cp_fwd = amax_per_step.amax(dim=1) - fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] - fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] - out_fp8 = None out_f16 = out.to(qkv_dtype) + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) - - if fp8 and is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv_dtype, - ) - else: - out_ret = out_f16 + out_fp8 = O_quantizer(out_f16) # final result + + out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8 - fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() - fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + q_save, kv_save, out_save = q, kv, out_fp8._data elif fp8 and is_input_fp8: - q_fp8 = Float8Tensor( - data=q, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, - ) - kv_fp8 = Float8Tensor( - data=kv, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - 
fp8_dtype=fp8_dtype_forward, - dtype=k_fp8.dtype, - ) - q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 - fp8_fwd_scales, fp8_fwd_scale_invs = None, None + q_save, kv_save, out_save = q, k, out_f16 else: q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 - fp8_fwd_scales, fp8_fwd_scale_invs = None, None - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( q_save, kv_save, out_save, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_fwd_scales, - fp8_fwd_scale_invs, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, *rng_states, *attn_biases, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.qkv_dtype = qkv_dtype + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.O_CP_quantizer = O_CP_quantizer + ctx.S_quantizer = S_quantizer + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dQKV_CP_quantizer = dQKV_CP_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.qkv_dtype = qkv_dtype + ctx.cp_group_a2a = cp_group_a2a ctx.cp_size_a2a = cp_size_a2a ctx.rank_a2a = rank_a2a @@ -2648,6 +2756,7 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + return out_ret @staticmethod @@ -2662,13 +2771,15 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - (*saved_tensors,) = ctx.saved_tensors - (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6] - (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8] - cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size] - cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2] - rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] - attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + saved_tensors = ctx.saved_tensors + + q, kv, out, 
softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( + restore_from_saved(ctx.tensor_objects, saved_tensors) + ) + cu_seqlens_q_per_step = other_tensors[:cp_size] + cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] + rng_states = other_tensors[cp_size * 2 : cp_size * 3] + attn_biases = other_tensors[cp_size * 3 : cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2724,50 +2835,40 @@ def backward(ctx, dout): dq = None dout_dtype = dout.dtype fused_attn_backend = None - fused_attn_qkv_dtype = None fused_attn_dqkv_dtype = None - amax_per_step = None - dout_fp8_dtype = None if ctx.fp8: if ctx.use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) dkv_fp8_ = torch.empty_like(dkv_fp8) if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
- ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + fused_attn_dqkv_dtype = dout._fp8_dtype dout = dout._data else: - dout = cast_to_fp8( - dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] - fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] - fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] - fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] - fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] - fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] - fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer else: assert False, "FP8 is only supported with Fused Attention!" 
else: if ctx.fp8_meta is not None and ctx.is_input_fp8: - q, kv = [x.from_float8(x.dtype) for x in [q, kv]] + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q, kv = q.dequantize(), kv.dequantize() if cp_size_a2a == 1: - dout = dout.from_float8(dout_dtype) - else: - dout_fp8_dtype = dout._fp8_dtype - dout_scale_inv = dout._scale_inv - dout = dout._data + dout = dout.dequantize() dq = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2776,7 +2877,6 @@ def backward(ctx, dout): p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] @@ -2795,14 +2895,9 @@ def backward(ctx, dout): True, ) if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = cast_from_fp8( - dout, - None, - None, - dout_fp8_dtype, - TE_DType[dout_dtype], - scale_inv=dout_scale_inv, # pylint: disable=used-before-assignment - ) + dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True) + dout = dout.dequantize() + dout = dout._data out = out.view(*q.shape) dout = dout.view(*q.shape) @@ -2827,6 +2922,8 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if _flash_attn_2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 for i in range(cp_size): # wait until KV is received @@ -2868,9 +2965,6 @@ def backward(ctx, dout): kv = p2p_comm_buffers[i % 2][0] q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None - if ctx.fp8 and ctx.use_fused_attention: - fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of 
fwd if causal: if i == (cp_size - 1): @@ -2899,17 +2993,39 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -2923,6 +3039,10 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -2934,8 +3054,13 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, 0) + elif 
_flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -2981,17 +3106,39 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3007,6 +3154,10 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -3018,8 +3169,13 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if 
_use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) + if _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -3066,17 +3222,40 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3092,6 +3271,11 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data 
+ dv_ = dv_._data + else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -3103,8 +3287,13 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -3129,17 +3318,39 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], - out, - dout, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3153,6 +3364,12 @@ def backward(ctx, dout): 
deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: dq_ = torch.empty_like(q) dkv_ = torch.empty_like(kv) @@ -3164,8 +3381,11 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -3328,23 +3548,13 @@ def backward(ctx, dout): dkv.add_(dkv_) if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] if ctx.qkv_format in ["bshd", "sbhd"]: # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dq, dkv = [ - cast_from_fp8( - x, - ctx.fp8_meta["scaling_bwd"], - META_DQKV_CP, - fp8_dtype_backward, - TE_DType[torch.float32], - ) - for x in [dq_fp8, dkv_fp8] - ] + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8) + dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8) + dq, dkv = [x.dequantize() for x in [dq, dkv]] dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: @@ -3364,10 +3574,8 @@ def backward(ctx, dout): dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dkv = [ - cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) - for x in [dq, dkv] - ] + assert torch.uint8 not in [dq.dtype, dkv.dtype] + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] dk, dv = dkv[0], dkv[1] if 
cp_size_a2a > 1: @@ -3386,22 +3594,14 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = [ - Float8Tensor( - data=x, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=dout_dtype, - ) - for x in [dq, dk, dv] - ] - if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + # converting torch.uint8 to float8tensor + if ctx.fp8 and ctx.is_input_fp8: + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) return ( None, @@ -3427,6 +3627,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3493,6 +3694,8 @@ def forward( cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + qkv_dtype = q.dtype + causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type assert not padding, f"{attn_mask_type} mask type is not supported!" @@ -3521,8 +3724,10 @@ def forward( fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if _flash_attn_2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3610,7 +3815,7 @@ def forward( q_, k_, v_, - TE_DType[q.dtype], + qkv_dtype, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, @@ -3631,19 +3836,31 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): + fa_forward_kwargs["window_size"] = window_size_per_step[i] + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( q_, k_, v_, *fa_forward_args_thd, causal=causal, - window_size=window_size_per_step[i], **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3673,6 +3890,8 @@ def forward( *softmax_lse_per_step, *rng_states, ) + + ctx.qkv_dtype = qkv_dtype ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group @@ -3754,6 +3973,8 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if _flash_attn_2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -3783,7 +4004,7 @@ def backward(ctx, dout): v_, out_, dout_, - TE_DType[q.dtype], + ctx.qkv_dtype, TE_DType[dout.dtype], aux_ctx_tensors, 
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, @@ -3811,6 +4032,11 @@ def backward(ctx, dout): ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] + if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size"] = window_size_per_step[i] + if _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( dout_, q_, @@ -3823,7 +4049,6 @@ def backward(ctx, dout): dv_per_step[i], *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, - window_size=window_size_per_step[i], **fa_backward_kwargs, ) @@ -3923,12 +4148,14 @@ def forward( fp8_meta, cp_group, cp_stream, + quantizers, ): # pylint: disable=missing-function-docstring if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) + qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3958,12 +4185,17 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_3_plus: + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size[0] + fa_forward_kwargs["window_size_right"] = window_size[1] if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if _flash_attn_2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 assert ( q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 @@ -3978,50 +4210,38 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" 
- qkv_dtype = q.dtype fused_attn_backend = None - fused_attn_qkv_dtype = None # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + is_output_fp8 = False if fp8: - if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward + is_output_fp8 = fp8_meta["recipe"].fp8_mha + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) + if fp8: + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v - q, k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [q_f16, k_f16, v_f16] - ] + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O - fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta_kwargs["amax_s_offset"] = META_S - fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta_kwargs["amax_o_offset"] = META_O + 
fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer else: assert False, "FP8 is only supported with Fused Attention!" else: if use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) @@ -4031,23 +4251,31 @@ def forward( if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v - q, k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [q_f16, k_f16, v_f16] - ] + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] batch_size = q.shape[batch_dim] if use_fused_attention: + q_part, k_part, v_part = q, k, v + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v, fake_dtype=qkv_dtype, internal=True + ) out, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, - k, - v, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -4060,6 +4288,8 @@ def forward( window_size=window_size, **fp8_meta_kwargs, ) + if fp8: + out = out._data else: fa_forward_args_thd = [] if qkv_format == "thd": @@ -4077,8 +4307,12 @@ def forward( causal=causal, **fa_forward_kwargs, ) - out, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + if not _flash_attn_2_7_0_plus: + out, softmax_lse = fa_outputs[4], fa_outputs[5] + rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + else: + out, softmax_lse = fa_outputs[0], fa_outputs[1] + rng_state = fa_outputs[3] if not 
_use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) @@ -4096,24 +4330,16 @@ def forward( if fp8: if is_output_fp8: - out_fp8 = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv_dtype, + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=False ) - out = out_fp8._data out_ret = out_fp8 + out = out_fp8._data else: - out_f16 = cast_from_fp8( - out, - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - TE_DType[q_f16.dtype], + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=False ) + out_f16 = out_fp8.dequantize() out_ret = out_f16 else: out_ret = out @@ -4122,30 +4348,22 @@ def forward( if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out elif is_input_fp8: - q_fp8, k_fp8, v_fp8 = [ - Float8Tensor( - data=x, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=out_fp8.dtype, - ) - for x in [q, k, v] - ] - q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8 + q_fp8 = QKV_quantizer.create_tensor_from_data( + q, fake_dtype=qkv_dtype, internal=False + ) + k_fp8 = QKV_quantizer.create_tensor_from_data( + k, fake_dtype=qkv_dtype, internal=False + ) + v_fp8 = QKV_quantizer.create_tensor_from_data( + v, fake_dtype=qkv_dtype, internal=False + ) + q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out else: q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 else: q_save, k_save, v_save, out_save = q, k, v, out - if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() - fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() - else: - fp8_fwd_scales, fp8_fwd_scale_invs = None, None - - ctx.save_for_backward( + tensors_to_save, 
tensor_objects = prepare_for_saving( q_save, k_save, v_save, @@ -4154,10 +4372,20 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_fwd_scales, - fp8_fwd_scale_invs, *aux_ctx_tensors, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.qkv_dtype = qkv_dtype + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.qkv_dtype = qkv_dtype + ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream @@ -4182,11 +4410,18 @@ def backward(ctx, dout): # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - q, k, v, out = saved_tensors[:4] - cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8] - fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10] - aux_ctx_tensors = saved_tensors[10:] + ( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *aux_ctx_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + dout_dtype = dout.dtype qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type @@ -4194,47 +4429,32 @@ def backward(ctx, dout): fused_attn_backend = None fused_attn_dqkv_dtype = None - fused_attn_qkv_dtype = None - dout_dtype = dout.dtype if ctx.fp8: + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_dqkv_dtype = fp8_dtype_backward + if ctx.use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = 
FusedAttnBackend["FP8"] if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv dout_fp8 = dout dout = dout_fp8._data else: dout_f16 = dout - dout = cast_to_fp8( - dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + dout = ctx.dO_quantizer(dout_f16)._data fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] - fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] - fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] - fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] - fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] - fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] - fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV] - fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] - fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][ - META_DQKV - ] + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + else: assert False, "FP8 is only supported with Fused Attention!" else: if ctx.fp8_meta is not None and ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
- q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] + q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]] if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] @@ -4263,25 +4483,53 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_3_plus: + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = ctx.window_size + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = ctx.window_size[0] + fa_backward_kwargs["window_size_right"] = ctx.window_size[1] if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if _flash_attn_2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 if ctx.use_fused_attention: + q_part = q + k_part = k + v_part = v + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dq, dk, dv, _ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, - k, - v, - out, - dout, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -4296,6 +4544,10 @@ def 
backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq = dq._data + dk = dk._data + dv = dv._data else: softmax_lse, rng_state = aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] @@ -4335,29 +4587,11 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - if ctx.is_input_fp8: - dq, dk, dv = [ - Float8Tensor( - data=x, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=dout_dtype, - ) - for x in [dq, dk, dv] - ] - else: - dq, dk, dv = [ - cast_from_fp8( - x, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - TE_DType[dout_dtype], - ) - for x in [dq, dk, dv] - ] + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + if not ctx.is_input_fp8: + dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] return ( None, @@ -4384,6 +4618,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4413,6 +4648,7 @@ def attn_forward_func_with_cp( window_size=None, fp8=False, fp8_meta=None, + quantizers=None, ) -> torch.Tensor: """ Attention implementation with context parallelism. 
@@ -4480,7 +4716,7 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) @@ -4488,7 +4724,7 @@ def attn_forward_func_with_cp( args += [window_size, cp_group, cp_stream] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream] + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -4720,15 +4956,34 @@ def forward( mixed_x_layer: torch.Tensor, split_dim: int, split_size_or_sections: Union[int, List[int], Tuple[int]], + squeeze=False, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections + if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( + mixed_x_layer, Float8Tensor + ): + return tuple( + Float8TensorBase( + fp8_scale_inv=mixed_x_layer._scale_inv, + fp8_dtype=mixed_x_layer._fp8_dtype, + data=x.squeeze(split_dim) if squeeze else x, + shape=x.squeeze(split_dim).shape if squeeze else x.shape, + quantizer=mixed_x_layer._quantizer, + ) + for x in torch.split( + mixed_x_layer._data, + split_size_or_sections=split_size_or_sections, + dim=split_dim, + ) + ) if isinstance(mixed_x_layer, Float8Tensor): return tuple( Float8Tensor.make_like( mixed_x_layer, - data=x, + data=x.squeeze(split_dim) if squeeze else x, + shape=x.squeeze(split_dim).shape if squeeze else x.shape, ) for x in torch.split( mixed_x_layer._data, @@ -4736,7 +4991,10 @@ def forward( dim=split_dim, ) ) - return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim) + out_list = torch.split(mixed_x_layer, 
split_size_or_sections, dim=split_dim) + if squeeze: + out_list = [x.squeeze(split_dim) for x in out_list] + return out_list @staticmethod def backward(ctx, *grad_outputs): @@ -4782,13 +5040,17 @@ def backward(ctx, *grad_outputs): new_shape, strides, ) - return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None + return ( + Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape), + None, + None, + ) grad_outputs_data = [x._data for x in grad_outputs] + data = torch.cat(grad_outputs_data, dim=split_dim) return ( - Float8Tensor.make_like( - grad_outputs[0], data=torch.cat(grad_outputs_data, dim=split_dim) - ), + Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape), + None, None, None, ) @@ -4923,19 +5185,14 @@ def forward( key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] - # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator - is_bf16 = query_layer.dtype == torch.bfloat16 matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], - dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype, + dtype=query_layer.dtype, device=torch.cuda.current_device(), ) - if is_in_onnx_export_mode() and is_bf16: - matmul_result = matmul_result.bfloat16() - scale = self.softmax_scale if apply_qk_layer_scaling: scale /= self.layer_number @@ -5323,6 +5580,7 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, ) -> torch.Tensor: """flash-attn fprop""" @@ -5373,7 +5631,7 @@ def forward( for x in (query_layer._data, key_layer._data, value_layer._data) ] query_layer, key_layer, value_layer = [ - Float8Tensor.make_like(x, data=x._data) + Float8Tensor.make_like(x, data=x._data, shape=x._data.shape) for x in (query_layer, key_layer, value_layer) ] if context_parallel: @@ -5476,6 +5734,7 @@ def forward( 
attn_mask_type=attn_mask_type, deterministic=self.deterministic, window_size=window_size, + quantizers=quantizers, ) else: @@ -5514,10 +5773,10 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["deterministic"] = self.deterministic - activation_dtype = query_layer.dtype if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + torch_orig_dtype = query_layer.dtype def convert_to_torch_float8(tensor, dtype): out = torch.Tensor().to(device=tensor.device, dtype=dtype) @@ -5534,960 +5793,118 @@ def convert_to_torch_float8(tensor, dtype): assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." - if isinstance(query_layer, Float8Tensor): - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv - else: + if not isinstance(query_layer, Float8Tensor): query_layer, key_layer, value_layer = ( - Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward) - for x in [query_layer, key_layer, value_layer] + QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] ) - fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv - fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv - fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv - query_layer, key_layer, value_layer = ( - convert_to_torch_float8(x, torch_dtype) - for x in [query_layer, key_layer, value_layer] - ) - try: - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_3_optional_forward_kwargs, + fa_3_optional_forward_kwargs["descale_q"] = ( + query_layer._scale_inv.unsqueeze(0) ) - except TypeError as e: - if 
_flash_attn_3_0_0_beta: - e.args = ( - e.args[0] - + ". Please update your flash-attn v3 (beta) installation as it " - + "may have added more supported arguments to its API. \n" - + _flash_attn_3_installation_steps, - ) + e.args[1:] - raise - - if fp8 and fp8_meta["recipe"].fp8_mha: - output = cast_to_fp8( - output, - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, + fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( + 0 ) - output = Float8Tensor( - data=output, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, + fa_3_optional_forward_kwargs["descale_v"] = ( + value_layer._scale_inv.unsqueeze(0) ) - else: - output = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) - - if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: - output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) - - if qkv_format == "sbhd": - # (bs)hd -> bs(hd) -> sb(hd) - if fp8 and fp8_meta["recipe"].fp8_mha: - output = Float8Tensor.make_like( - output, - data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) - .transpose(0, 1) - .contiguous(), - ) - else: - output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) - elif qkv_format == "bshd": - # (bs)hd -> bs(hd) - output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) - elif qkv_format == "thd": - # thd -> t(hd) - output = output.reshape(output.shape[0], -1) - - return output.contiguous() - - -def _combine_tensors( - tensors: List[torch.Tensor], - dim: int, -) -> torch.Tensor: - """Combine tensors along a particular dimension""" - - num_tensors = len(tensors) - new_shape = list(tensors[0].shape) - new_shape.insert(dim, num_tensors) - new_stride = list(tensors[0].stride()) - 
new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) - if isinstance(tensors[0], Float8Tensor): - combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype) - combined_tensor.set_( - tensors[0]._data.untyped_storage(), - tensors[0]._data.storage_offset(), - new_shape, - new_stride, - ) - combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor) - else: - combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype) - combined_tensor.set_( - tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride - ) - - return combined_tensor - - -class FusedAttnFunc_qkvpacked(torch.autograd.Function): - """Function for FusedAttention with packed QKV input""" - - @staticmethod - def forward( - ctx, - is_training, - max_seqlen, - cu_seqlens, - cu_seqlens_padded, - qkv, - qkv_dtype, - attn_bias, - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - fused_attention_backend, - use_FAv2_bwd, - fp8, - fp8_meta, - deterministic, - ): - # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha - if fp8: - is_input_fp8 = isinstance(qkv, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - assert ( - qkv_group == 1 - ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." 
- if is_input_fp8: - qkv_fp8 = qkv._data - else: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8( - qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(qkv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, - max_seqlen, - cu_seqlens, - qkv_fp8, - fp8_dtype_forward, - fused_attention_backend, - attn_bias, - cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - out_save = out_ret - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - if is_input_fp8: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - fp8_tensors = ( - qkv_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - 
fp8_meta["scaling_fwd"].scale_inv.clone(), - ) - else: - out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, - max_seqlen, - cu_seqlens, - qkv, - qkv_dtype, - fused_attention_backend, - attn_bias, - cu_seqlens_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - fp8_tensors = (None, None, None, None) - out_save = out_ret - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) - ctx.save_for_backward( - *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors - ) - ctx.fp8_meta = fp8_meta - ctx.max_seqlen = max_seqlen - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.window_size = window_size - ctx.fused_attention_backend = ( - fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] - ) - ctx.use_FAv2_bwd = use_FAv2_bwd - ctx.deterministic = deterministic - - return out_ret - - @staticmethod - def backward(ctx, d_out): - # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
- d_out_f8tensor = d_out - d_out = d_out._data - - d_out = d_out.contiguous() - ( - qkv, - out, - cu_seqlens, - cu_seqlens_padded, - qkv_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors - rest = [None] - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() - if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors - dqkv = torch.empty_like(qkv) - d_out, q, k, v, out = [ - maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out) - ] - flash_attn_cuda_bwd( - d_out, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.attn_scale, - False, - "causal" in ctx.attn_mask_type, - None, - rng_state, - ) - dqkv = dqkv[..., : d_out.shape[-1]] - else: - with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) - if ctx.is_output_fp8: - d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) - dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, - cu_seqlens, - qkv_fp8, - out_fp8, - d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - 
ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - if ctx.is_input_fp8: - dqkv = Float8Tensor( - data=dqkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dqkv_c_fp8 = dqkv_fp8.view( - -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] - ) - dqkv = cast_from_fp8( - dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dqkv_fp8.shape) - else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(qkv.dtype) - dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, - cu_seqlens, - qkv, - out, - d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_padded, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - dqkv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) - return ( - None, - None, - None, - None, - dqkv, - None, - rest[0], - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class FusedAttnFunc_kvpacked(torch.autograd.Function): - """Function for FusedAttention with packed KV input""" - - @staticmethod - 
def forward( - ctx, - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q, - kv, - qkv_dtype, - attn_bias, - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - fused_attention_backend, - use_FAv2_bwd, - fp8, - fp8_meta, - deterministic, - ): - # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha - if fp8: - assert isinstance(kv, q.__class__), "q and kv must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if is_input_fp8: - q_fp8, kv_fp8 = q._data, kv._data - else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f"but found {qkv_layout}." 
- ) - q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( - q.shape - ) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8( - kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(kv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - kv_fp8, - fp8_dtype_forward, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - out_save = out_ret - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - if is_input_fp8: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, 
out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - fp8_tensors = ( - q_fp8, - kv_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) - else: - out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - out_save = out_ret - fp8_tensors = (None, None, None, None, None) - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) - ctx.save_for_backward( - *qkvo_tensors, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *fp8_tensors, - *aux_ctx_tensors, - ) - ctx.fp8_meta = fp8_meta - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.window_size = window_size - ctx.fused_attention_backend = ( - fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] - ) - ctx.use_FAv2_bwd = use_FAv2_bwd - ctx.deterministic = deterministic - - return out_ret - - @staticmethod - def backward(ctx, d_out): - # pylint: 
disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data - - d_out = d_out.contiguous() - ( - q, - kv, - out, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q_fp8, - kv_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors - rest = [None] - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() - if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)] - flash_attn_cuda_bwd( - d_out, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.attn_scale, - False, - "causal" in ctx.attn_mask_type, - None, - rng_state, - ) - dq = dq[..., : d_out.shape[-1]] - dkv = dkv[..., : d_out.shape[-1]] - else: - with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) - if ctx.is_output_fp8: - d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) - dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - kv_fp8, - out_fp8, - d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - 
fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - if ctx.is_input_fp8: - dq = Float8Tensor( - data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dkv = Float8Tensor( - data=dkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] + query_layer, key_layer, value_layer = ( + convert_to_torch_float8(x, torch_dtype) + for x in [query_layer, key_layer, value_layer] ) - dkv = cast_from_fp8( - dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dkv_fp8.shape) + try: + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_3_optional_forward_kwargs, + ) + except TypeError as e: + if _flash_attn_3_0_0_beta: + e.args = ( + e.args[0] + + ". 
Please update your flash-attn v3 (beta) installation as it " + + "may have added more supported arguments to its API. \n" + + _flash_attn_3_installation_steps, + ) + e.args[1:] + raise + + if fp8: + output = output.to(dtype=torch_orig_dtype) + if fp8 and fp8_meta["recipe"].fp8_mha: + O_quantizer = quantizers["scaling_fwd"][META_O] + output = O_quantizer(output) else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) - dq, dkv, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - out, - d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, ) - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - dq, - dkv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) - return ( - None, - None, - None, - None, - None, - None, - None, - dq, - dkv, - None, - rest[0], - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: + output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + + if qkv_format == "sbhd": + # (bs)hd -> bs(hd) -> sb(hd) + if fp8 and 
fp8_meta["recipe"].fp8_mha: + output_data = ( + output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous() + ) + output = Float8Tensor.make_like( + output, + data=output_data, + shape=output_data.shape, + ) + else: + output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) + elif qkv_format == "bshd": + # (bs)hd -> bs(hd) + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) + elif qkv_format == "thd": + # thd -> t(hd) + output = output.reshape(output.shape[0], -1) + + return output.contiguous() + + +def _combine_tensors( + tensors: List[torch.Tensor], + dim: int, +) -> torch.Tensor: + """Combine tensors along a particular dimension""" + + num_tensors = len(tensors) + new_shape = list(tensors[0].shape) + new_shape.insert(dim, num_tensors) + if isinstance(tensors[0], Float8Tensor): + new_stride = list(tensors[0]._data.stride()) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) + combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype) + combined_tensor.set_( + tensors[0]._data.untyped_storage(), + tensors[0]._data.storage_offset(), + new_shape, + new_stride, + ) + combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape) + else: + new_stride = list(tensors[0].stride()) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) + combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype) + combined_tensor.set_( + tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride ) + return combined_tensor + class FusedAttnFunc(torch.autograd.Function): """Function for FusedAttention with separate Q, K, V tensors""" @@ -6519,56 +5936,51 @@ def forward( use_FAv2_bwd, fp8, fp8_meta, + quantizers, deterministic, ): # pylint: disable=missing-function-docstring # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - 
is_output_fp8 = fp8_meta["recipe"].fp8_mha + is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False + fake_dtype = q.dtype + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) if fp8: fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + q_fp8, k_fp8, v_fp8 = None, None, None if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data + q_fp8, k_fp8, v_fp8 = q, k, v else: # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8( - qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(qkv.shape) - q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1]) - q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] - if qkv_group == 2: - q_fp8 = cast_to_fp8( - q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8( - kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(kv.shape) - k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1]) - k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] - if qkv_group == 3: - q_fp8 = cast_to_fp8( - q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(q.shape) - k_fp8 = cast_to_fp8( - k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - 
).view(k.shape) - v_fp8 = cast_to_fp8( - v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(v.shape) + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = QKV_quantizer(qkv) + q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) + case 2: + q_fp8 = QKV_quantizer(q) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = QKV_quantizer(kv_c) + k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) + case 3: + q_fp8 = QKV_quantizer(q) + k_fp8 = QKV_quantizer(k) + v_fp8 = QKV_quantizer(v) + case _: + raise "Invalid qkv_layout " + qkv_layout out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -6578,23 +5990,13 @@ def forward( q_fp8, k_fp8, v_fp8, - fp8_dtype_forward, + fake_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + S_quantizer, + O_quantizer, attn_scale, dropout_p, fast_zero_fill, @@ -6605,22 +6007,9 @@ def forward( rng_gen, ) if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) + out_ret = out_fp8 else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - 
).view(out_fp8.shape) + out_ret = out_fp8.dequantize().view(out_fp8.shape) out_save = out_ret if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): @@ -6631,75 +6020,25 @@ def forward( dim = qkv_layout.find("3") qkv = _combine_tensors([q, k, v], dim) qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] + qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) + q = q.dequantize() dim = qkv_layout.split("_")[1].find("2") kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] + kv_no_fp8 = kv.dequantize() + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) if qkv_group == 3: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) + q = q.dequantize() + k = k.dequantize() + v = v.dequantize() if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - 
).view(out_fp8.shape) - - fp8_tensors = ( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) + out_save = out_fp8.dequantize() + + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: + out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -6709,23 +6048,13 @@ def forward( q, k, v, - qkv_dtype, + fake_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset + None, # s_quantizer + None, # o_quantizer attn_scale, dropout_p, fast_zero_fill, @@ -6736,7 +6065,7 @@ def forward( rng_gen, ) out_save = out_ret - fp8_tensors = (None, None, None, None, None, None) + fp8_tensors = (None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) @@ -6758,18 +6087,27 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, *qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - *fp8_tensors, *aux_ctx_tensors, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta + + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.S_quantizer = S_quantizer + ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv + ctx.fake_dtype = fake_dtype ctx.qkv_dtype = qkv_dtype ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p @@ -6793,11 +6131,13 @@ def backward(ctx, d_out): assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in 
Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data d_out = d_out.contiguous() ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, q, k, v, @@ -6806,14 +6146,11 @@ def backward(ctx, d_out): cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + + aux_ctx_tensors = other_tensors + if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() rest = [None] @@ -6850,20 +6187,10 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn"): if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) if ctx.is_output_fp8: d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) + d_out_fp8 = ctx.dO_quantizer(d_out) dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6874,22 +6201,15 @@ def backward(ctx, d_out): v_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, + ctx.fake_dtype, + ctx.qkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - 
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.S_quantizer, + ctx.dP_quantizer, + ctx.dQKV_quantizer, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, @@ -6900,95 +6220,36 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.is_input_fp8: - dq = Float8Tensor( - data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dk = Float8Tensor( - data=dk_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dv = Float8Tensor( - data=dv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: + if not ctx.is_input_fp8: qkv_group = len(ctx.qkv_layout.split("_")) if qkv_group == 1: dim = ctx.qkv_layout.find("3") - dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim) - dqkv_c_fp8 = dqkv_fp8.view( - -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] + dqkv_fp8_data = _combine_tensors( + [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim ) - dqkv = cast_from_fp8( - dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dqkv_fp8.shape) - dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1]) - dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] + dqkv_fp8 = dq_fp8.make_like( + tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape + ) + dqkv = dqkv_fp8.dequantize() + dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) if qkv_group == 2: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) + dq = dq_fp8.dequantize() dim = ctx.qkv_layout.split("_")[1].find("2") dkv_fp8 = 
_combine_tensors([dk_fp8, dv_fp8], dim) dkv_c_fp8 = dkv_fp8.view( -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] ) - dkv = cast_from_fp8( - dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dkv_fp8.shape) - dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1]) - dk, dv = [x.squeeze(dim) for x in [dk, dv]] + dkv = dkv_c_fp8.dequantize() + dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True) if qkv_group == 3: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) - dk = cast_from_fp8( - dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dk_fp8.shape) - dv = cast_from_fp8( - dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dv_fp8.shape) + dq = dq_fp8.dequantize() + dk = dk_fp8.dequantize() + dv = dv_fp8.dequantize() + else: + dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) + if isinstance(d_out, QuantizedTensor): + d_out = d_out.dequantize() dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6999,7 +6260,7 @@ def backward(ctx, d_out): v, out, d_out, - ctx.qkv_dtype, + ctx.fake_dtype, ctx.qkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, @@ -7008,13 +6269,6 @@ def backward(ctx, d_out): None, None, None, - None, - None, - None, - None, - None, - None, - None, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, @@ -7055,6 +6309,7 @@ def backward(ctx, d_out): None, None, None, + None, ) # else, return (dqkv, dbias) return ( @@ -7085,6 +6340,7 @@ def backward(ctx, d_out): None, None, None, + None, ) @@ -7184,6 +6440,7 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: 
Optional[Dict[str, Any]] = None, + quantizers=None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -7321,6 +6578,7 @@ def forward( window_size=window_size, fp8=fp8, fp8_meta=fp8_meta, + quantizers=quantizers, ) else: with self.attention_dropout_ctx(): @@ -7349,6 +6607,7 @@ def forward( use_FAv2_bwd, fp8, fp8_meta, + quantizers, self.deterministic, ) @@ -7736,7 +6995,6 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, - is_first_microbatch: Optional[bool] = None, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -7906,27 +7164,13 @@ def forward( Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. 
- When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - * it also allows skipping gradient accumulation during the - first microbatch (since it is the first gradient being - produced) """ + with self.prepare_forward( query_layer, - is_first_microbatch, num_gemms=3, allow_non_contiguous=True, ) as query_layer: - if self.fp8: if self.fp8_meta["recipe"].fp8_mha: if not self.fp8_meta["recipe"].fp8_dpa: @@ -8290,6 +7534,7 @@ def forward( max_seqlen_kv=max_seqlen_kv, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, ) if use_fused_attention: @@ -8358,6 +7603,7 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, ) from .cpu_offload import CPUOffloadEnabled @@ -8569,11 +7815,11 @@ def __init__( fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", @@ -9035,16 +8281,9 @@ def forward( # not qkv_weight_interleaved: # [sq, b, (np/ng + 2), ng, hn] # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn] - if not is_in_onnx_export_mode(): - query_layer, key_layer, value_layer = _SplitAlongDim.apply( - mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) - ) - else: - query_layer, key_layer, value_layer = torch.split( - mixed_x_layer, - (num_queries_per_key_value, 1, 1), - dim=split_dim, - ) + query_layer, key_layer, value_layer = _SplitAlongDim.apply( + mixed_x_layer, split_dim, 
(num_queries_per_key_value, 1, 1) + ) if self.qkv_format == "thd": query_layer, key_layer, value_layer = ( @@ -9086,18 +8325,11 @@ def forward( mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # mixed_kv_layer --> 2 [sk, b, ng, hn] - if not is_in_onnx_export_mode(): - key_layer, value_layer = _SplitAlongDim.apply( - mixed_kv_layer, - split_dim, - mixed_kv_layer.shape[split_dim] // 2, - ) - else: - key_layer, value_layer = torch.split( - mixed_kv_layer, - mixed_kv_layer.shape[split_dim] // 2, - dim=split_dim, - ) + key_layer, value_layer = _SplitAlongDim.apply( + mixed_kv_layer, + split_dim, + mixed_kv_layer.shape[split_dim] // 2, + ) key_layer, value_layer = ( x.reshape( x.size(0), @@ -9208,10 +8440,10 @@ def forward( # =================== # Output. [sq, b, h] # =================== - projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, + fp8_grad=isinstance(context_layer, QuantizedTensor), ) if self.return_bias: diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index c1790313ac..ff475caf21 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -16,6 +16,8 @@ """ TE_DType = { torch.uint8: tex.DType.kByte, + torch.float8_e4m3fn: tex.DType.kFloat8E4M3, + torch.float8_e5m2: tex.DType.kFloat8E5M2, torch.int32: tex.DType.kInt32, torch.float32: tex.DType.kFloat32, torch.half: tex.DType.kFloat16, @@ -59,3 +61,5 @@ GemmParallelModes = ("row", "column", None) dist_group_type = torch.distributed.ProcessGroup + +MXFP8_BLOCK_SCALING_SIZE = 32 diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index be911fcd95..944d1849bf 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -7,8 +7,3 @@ from .fused_attn import * from .gemm import * -from .transpose import * -from .activation import * -from 
.normalization import * -from .cast import * -from .padding import * diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py deleted file mode 100644 index aec972994a..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Helper functions for C++ extensions""" -import functools -from typing import Dict, Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex - - -@functools.lru_cache(maxsize=None) -def empty_tensor() -> torch.Tensor: - """Get tensor with no entries and no data""" - return torch.Tensor() - - -def canonicalize_fp8_scales( - *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - fp8_meta: Optional[tex.FP8TensorMeta] = None, - fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, - allow_multiple_offsets: bool = True, -) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: - """Canonicalize FP8 scaling factors (scale, amax, scale-inverse) - - If a scaling factor is not provided, try to access it within the - FP8 meta tensors. Returns dict with tensors and dict with tensor - offsets. 
- - """ - - # Default: use provided scales with no offsets - scale_offset = 0 - amax_offset = 0 - scale_inv_offset = 0 - - # Get scales from FP8 meta tensors if needed - if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)): - if fp8_meta_index is None: - raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`") - fp8_meta_index = int(fp8_meta_index) - if scale is None: - scale = fp8_meta.scale - scale_offset = fp8_meta_index - if amax is None: - amax = fp8_meta.amax_history - amax_offset = fp8_meta_index - if scale_inv is None: - scale_inv = fp8_meta.scale_inv - scale_inv_offset = fp8_meta_index - - # Construct empty tensors if needed - if scale is None: - scale = empty_tensor() - scale_offset = 0 - if amax is None: - amax = empty_tensor() - amax_offset = 0 - if scale_inv is None: - scale_inv = empty_tensor() - scale_inv_offset = 0 - - # Force offsets to be the same if needed - if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: - if scale_offset != 0: - scale = scale[scale_offset:] - scale_offset = 0 - if amax_offset != 0: - amax = amax[:, amax_offset:] - amax_offset = 0 - if scale_inv_offset != 0: - scale_inv = scale_inv[scale_inv_offset:] - scale_inv_offset = 0 - - # Pack tensors and offsets into dicts - tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv} - offsets = { - "scale_offset": scale_offset, - "amax_offset": amax_offset, - "scale_inv_offset": scale_inv_offset, - } - return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py deleted file mode 100644 index 534e71d134..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Python interface for activation extensions""" -from typing import Optional, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - -__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] - - -def gelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """GeLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.gelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def relu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.relu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def geglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - 
amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """GeGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.geglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def reglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.reglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def swiglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """SwiGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.swiglu_ts( - inp, - fp8_scales["scale"], - 
fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def qgelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """QuickGELU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.qgelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def srelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.srelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py deleted file mode 100644 index 9c21edccec..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# -# See LICENSE for license information. - -"""Python interface for cast extensions""" -from typing import Optional, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - -__all__ = ["cast_to_fp8", "cast_from_fp8"] - - -def cast_to_fp8( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Cast input to FP8""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch FP8 cast kernel - if inp.nelement() == 0: - if out is None: - out = torch.empty_like(inp, dtype=torch.uint8) - elif out is None: - out = torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - else: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_scales["scale"], - out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - return out - - -def cast_from_fp8( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - itype: tex.DType, - otype: tex.DType, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Cast input from FP8""" - - # Get scaling factors from FP8 meta tensors if needed - scale_inv_offset = 0 - if (fp8_meta_tensor is not None) and (scale_inv is None): - if fp8_tensor is None: - raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`") - scale_inv = fp8_meta_tensor.scale_inv - 
scale_inv_offset = int(fp8_tensor) - - # Construct empty tensors if needed - if scale_inv is None: - raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`") - - # Launch FP8 cast kernel - return torch.ops.tex_ts.cast_from_fp8_ts( - inp, - scale_inv, - scale_inv_offset, - itype, - otype, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 332b4e52ee..b91a6c1751 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -4,7 +4,7 @@ """Python interface for fused attention extensions""" import math -from typing import Tuple, List, Union +from typing import Tuple, List, Union, Optional import torch import transformer_engine_torch as tex from transformer_engine_torch import ( @@ -13,13 +13,10 @@ NVTE_Mask_Type, NVTE_Fused_Attn_Backend, ) +from ..tensor.quantized_tensor import Quantizer __all__ = [ - "fused_attn_fwd_qkvpacked", - "fused_attn_bwd_qkvpacked", - "fused_attn_fwd_kvpacked", - "fused_attn_bwd_kvpacked", "fused_attn_fwd", "fused_attn_bwd", ] @@ -89,803 +86,6 @@ META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 -def fused_attn_fwd_qkvpacked( - is_training: bool, - max_seqlen: int, - cu_seqlens: torch.Tensor, - qkv: torch.Tensor, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - attn_bias: torch.Tensor = None, - cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbh3d", - 
attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - rng_gen: torch.Generator = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention FWD for packed QKV input. - - Parameters - ---------- - is_training: bool - if True, runs training and produces auxiliary tensors aux_ctx_tensors - for the backward; if False, runs inference and doesn't produce aux_ctx_tensors - max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(seqlens), - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - cu_seqlens: torch.Tensor - cumulative sequence lengths for QKV; shape [batch_size + 1] - qkv: torch.Tensor - input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv - cu_seqlens_padded: torch.Tensor, default = None - cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset 
in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbh3d" - layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - rng_gen: torch.Generator, default = None - random number generator; - if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen - - Returns - ---------- - o: torch.Tensor - output tensor O, of the attention calculation; same data type as QKV; - same shape as Q, i.e. 
thd, sbhd or bshd (see `qkv_layout` for details) - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] - if is_training is False, aux_ctx_tensors = None - - softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen, max_seqlen], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - softmaxStats: torch.Tensor - log(sum(e^(x - max(x)))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen - state of the random number generator; - [seed, offset], dtype uint64 - """ - - if attn_scale is None: - d = qkv.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == qkv.dtype, "attn_bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
- - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." 
- else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") - - # execute kernel - output_tensors = tex.fused_attn_fwd_qkvpacked( - max_seqlen, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens, - qkv, - qkv_dtype, - cu_seqlens_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) - - # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] - - -def fused_attn_bwd_qkvpacked( - max_seqlen: int, - cu_seqlens: torch.Tensor, - qkv: torch.Tensor, - o: torch.Tensor, - d_o: torch.Tensor, - qkv_dtype: tex.DType, - dqkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor], - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbh3d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - deterministic: bool = False, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention BWD for packed QKV input. 
- - Parameters - ---------- - max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(seqlens) - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - cu_seqlens: torch.Tensor - cumulative sequence lengths for QKV; shape [batch_size + 1] - qkv: torch.Tensor - input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details) - o: torch.Tensor - input tensor O (output of forward); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - d_o: torch.Tensor - input tensor dO (gradient of O); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype - dqkv_dtype: tex.DType - data type of dQKV; in tex.DType, not torch.dtype - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. 
- cu_seqlens_padded: torch.Tensor, default = None - cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQKV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQKV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbh3d" - layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; 
{"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - deterministic: bool, default = False - whether to execute the backward pass with deterministic behaviours. - - Returns - ---------- - d_qkv: torch.Tensor - gradient tensor of QKV; same data type and shape as QKV - d_bias: torch.Tensor, optional - gradient tensor of Bias when attn_bias_type is "pre_scale_bias" - or "post_scale_bias"; same data type and shape as Bias - """ - - if attn_scale is None: - d = qkv.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." 
- assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - - # execute kernel - output_tensors = tex.fused_attn_bwd_qkvpacked( - max_seqlen, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - deterministic, - cu_seqlens, - qkv, - o, - d_o, - qkv_dtype, - dqkv_dtype, - aux_ctx_tensors, - cu_seqlens_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, - ) - - return output_tensors - - -def fused_attn_fwd_kvpacked( - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - q: torch.Tensor, - kv: torch.Tensor, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - attn_bias: torch.Tensor = None, - cu_seqlens_q_padded: torch.Tensor = None, - cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbhd_sbh2d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - rng_gen: torch.Generator = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention FWD for packed KV input. 
- - Parameters - ---------- - is_training: bool - if True, runs training and produces auxiliary tensors aux_ctx_tensors - for the backward; if False, runs inference and doesn't produce aux_ctx_tensors - max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(seqlens_q), - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(seqlens_kv), - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - cu_seqlens_q: torch.Tensor - cumulative sequence lengths for Q; shape [batch_size + 1] - cu_seqlens_kv: torch.Tensor - cumulative sequence lengths for KV; shape [batch_size + 1] - q: torch.Tensor - input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details) - kv: torch.Tensor - packed input tensor KV; shape 2hd or h2d (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q and KV; in tex.DType, not torch.dtype - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. 
- attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv - cu_seqlens_q_padded: torch.Tensor, default = None - cumulative sequence offsets for Q; shape [batch_size + 1] - cu_seqlens_kv_padded: torch.Tensor, default = None - cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() 
method - qkv_layout: str, default = "sbhd_sbh2d" - layout of QKV; - {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - rng_gen: torch.Generator, default = None - random number generator; - if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen - - Returns - ---------- - o: torch.Tensor - output tensor O, of the attention calculation; same data type as QKV; - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] - if is_training is False, aux_ctx_tensors = None - - softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - softmaxStats: torch.Tensor - log(sum(e^(x - max(x)))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - 3. 
if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen - state of the random number generator; - [seed, offset], dtype uint64 - """ - - if attn_scale is None: - d = q.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." 
- assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." - else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") - - # execute kernel - output_tensors = tex.fused_attn_fwd_kvpacked( - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - qkv_dtype, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) - - # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] - - -def fused_attn_bwd_kvpacked( - max_seqlen_q: int, - max_seqlen_kv: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - q: torch.Tensor, - kv: torch.Tensor, - o: torch.Tensor, - d_o: torch.Tensor, - qkv_dtype: tex.DType, - dqkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor], - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - cu_seqlens_q_padded: torch.Tensor = None, - cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbhd_sbh2d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - deterministic: bool = False, -) -> 
Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention BWD for packed KV input. - - Parameters - ---------- - max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(seqlens_q), - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(seqlens_kv), - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - cu_seqlens_q: torch.Tensor - cumulative sequence lengths for Q; shape [batch_size + 1] - cu_seqlens_kv: torch.Tensor - cumulative sequence lengths for KV; shape [batch_size + 1] - q: torch.Tensor - input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details) - kv: torch.Tensor - packed input tensor KV; shape h2d or 2hd (see `qkv_layout` for details) - o: torch.Tensor - input tensor O (output of forward); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - d_o: torch.Tensor - input tensor dO (gradient of O); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q and KV; in tex.DType, not torch.dtype - dqkv_dtype: tex.DType - data type of dQ and dKV; in tex.DType, not torch.dtype - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. 
- cu_seqlens_q_padded: torch.Tensor, default = None - cumulative sequence offsets for Q; shape [batch_size + 1] - cu_seqlens_kv_padded: torch.Tensor, default = None - cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQKV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations, - P = Q * K.T - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQKV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbhd_sbh2d" - layout of QKV; - {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"} - attn_bias_type: str, default 
= "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - deterministic: bool, default = False - whether to execute the backward pass with deterministic behaviours. - - Returns - ---------- - d_q: torch.Tensor - gradient tensor of Q; same data type and shape as Q - d_kv: torch.Tensor - gradient tensor of KV; same data type and shape as KV - d_bias: torch.Tensor, optional - gradient tensor of Bias when attn_bias_type is "pre_scale_bias" - or "post_scale_bias"; same data type and shape as Bias - """ - - if attn_scale is None: - d = q.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." 
- assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - - # execute kernel - output_tensors = tex.fused_attn_bwd_kvpacked( - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - deterministic, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - o, - d_o, - qkv_dtype, - dqkv_dtype, - aux_ctx_tensors, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, - ) - - return output_tensors - - def fused_attn_fwd( is_training: bool, max_seqlen_q: int, @@ -895,23 +95,13 @@ def fused_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - qkv_dtype: tex.DType, + fake_dtype: torch.dtype, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, + s_quantizer: Quantizer = None, + o_quantizer: Quantizer = None, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -946,8 +136,9 @@ def fused_attn_fwd( 
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details) v: torch.Tensor input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q, K and V; in tex.DType, not torch.dtype + fake_dtype: tex.DType + data type of Q, K and V - in case of high precision, fake dtype in case of FP8; + in torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. attn_bias: torch.Tensor, default = None @@ -957,30 +148,10 @@ def fused_attn_fwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O + s_quantizer: Quantizer, default = None + Quantizer object for the intermediate value S. 
+ o_quantizer: Quantizer, default = None + Quantizer object for the output of the attention. attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -1068,17 +239,16 @@ def fused_attn_fwd( ) // BACKEND_F16m512_FP8_THREADS_PER_CTA assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention." + assert ( + o_quantizer is not None + ), "o_quantizer is required as an input for FP8 fused attention." 
else: raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel + output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -1095,21 +265,11 @@ def fused_attn_fwd( q, k, v, - qkv_dtype, + fake_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, + s_quantizer, + o_quantizer, attn_bias, rng_gen, rng_elts_per_thread, @@ -1129,23 +289,16 @@ def fused_attn_bwd( v: torch.Tensor, o: torch.Tensor, d_o: torch.Tensor, - qkv_dtype: tex.DType, + fake_dtype: torch.dtype, dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, + s_quantizer: Quantizer = None, + dp_quantizer: Quantizer = None, + dqkv_quantizer: Quantizer = None, + attn_scale: Optional[float] = None, dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", @@ -1181,8 +334,9 @@ def fused_attn_bwd( d_o: torch.Tensor input tensor dO (gradient of O); same data type as Q, K and V; same shape as Q - qkv_dtype: tex.DType - data type of Q, K and V; in tex.DType, not torch.dtype + fake_dtype: tex.DType + data type of Q, K and V - in case of high precision, fake dtype in case of FP8; + in torch.dtype dqkv_dtype: tex.DType data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors: List[torch.Tensor] @@ -1194,30 +348,12 @@ def fused_attn_bwd( cumulative sequence 
offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQ, dK and dV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations, - P = Q * K.T - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default + s_quantizer: Quantizer, default = None + Quantizer object for the intermediate value S. + dp_quantizer: Quantizer, default = None + Quantizer object for the intermediate value dP. + dqkv_quantizer: Quantizer, default = None + Quantizer object for the output values of the fused_attn_bwd. 
dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -1253,7 +389,6 @@ def fused_attn_bwd( gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias """ - if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -1268,21 +403,19 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." + assert ( + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention backward." + assert ( + dp_quantizer is not None + ), "dp_quantizer is required as an input for FP8 fused attention backward." + assert ( + dqkv_dtype is not None + ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( len(aux_ctx_tensors) == 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
- # execute kernel output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -1301,21 +434,14 @@ def fused_attn_bwd( v, o, d_o, - qkv_dtype, + fake_dtype, dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, + s_quantizer, + dp_quantizer, + dqkv_quantizer, ) return output_tensors diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index c55f5a9fd4..948a13a03e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -4,499 +4,223 @@ """Python interface for GEMM extensions""" import functools -from typing import Optional, Tuple, Union, List +from typing import Iterable, Optional, Tuple, Union, List +import os import torch import transformer_engine_torch as tex from ..constants import TE_DType -from ..utils import assert_dim_for_fp8_exec +from ..utils import assert_dim_for_fp8_exec, get_sm_count +from ..tensor.quantized_tensor import Quantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = [ - "gemm", - "fp8_gemm", - "grouped_gemm", - "fp8_grouped_gemm", + "general_gemm", + "general_grouped_gemm", ] @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" - return torch.Tensor() + return torch.Tensor().cuda() -def fp8_gemm( - A: torch.Tensor, - A_scale_inv: torch.Tensor, - A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - A_dtype: tex.DType, - B: torch.Tensor, - B_scale_inv: torch.Tensor, - B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - B_dtype: tex.DType, - out_dtype: torch.dtype, - workspace: torch.Tensor, - gelu: bool = False, - accumulate: bool = False, - out: Optional[torch.Tensor] = None, - 
out_index=None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - bias: Optional[torch.Tensor] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, - ub_algo: tex.CommOverlapAlgo = None, - ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, - extra_output_tensor: torch.Tensor = None, -) -> torch.Tensor: - """TN layout GEMM with fp8 inputs.""" +def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str): + """Swizzle gemm inputs and return original scaling factor inverses.""" + if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase): + return None - empty_tensor = _empty_tensor() - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_index is not None - assert_dim_for_fp8_exec(A) - assert_dim_for_fp8_exec(B) - assert A.dtype == torch.uint8 - assert B.dtype == torch.uint8 - - if out is None: - out = torch.empty( - B.shape[0], - A.shape[0], - dtype=out_dtype, - device="cuda", - ) + original_scale_inverses = ( + A._rowwise_scale_inv, + A._columnwise_scale_inv, + B._rowwise_scale_inv, + B._columnwise_scale_inv, + ) + + if layout[0] == "T": + A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv) else: - if not out.is_contiguous(): - raise ValueError("Output tensor is not contiguous.") + A._columnwise_scale_inv = tex.columnwise_swizzle( + A._columnwise_data, A._columnwise_scale_inv + ) - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias.dtype - if gelu: - gelu_input = torch.empty_like(out, dtype=bias_dtype) + if layout[1] == "N": + B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv) else: - gelu_input = empty_tensor - bias_dtype = TE_DType[bias_dtype] + B._columnwise_scale_inv = tex.columnwise_swizzle( + B._columnwise_data, B._columnwise_scale_inv + ) - out_dtype = TE_DType[out.dtype] if D_dtype is None else 
D_dtype + return original_scale_inverses - args = ( - A, - A_scale_inv, - A_fp8_tensor, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor, - B_dtype, - False, # transb - out, - empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index], - out_dtype, - empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index], - bias if use_bias else empty_tensor, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - ) - fn = torch.ops.tex_ts.te_gemm_ts - if ub_algo is not None: - assert ub is not None, "ub object is None!" - if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.AG, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.RS, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: - fn = ub.atomic_gemm_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = 
ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: - fn = ub.atomic_gemm_overlap_rs - assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: - fn = ub.atomic_gemm_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) - - return out, gelu_input - - -def gemm( + +def reset_swizzled_inputs(A, B, scale_inverses): + """Reset the swizzled scale inverses after GEMM.""" + if scale_inverses is not None: + ( + A._rowwise_scale_inv, + A._columnwise_scale_inv, + B._rowwise_scale_inv, + B._columnwise_scale_inv, + ) = scale_inverses + + +def general_gemm( A: torch.Tensor, B: torch.Tensor, - dtype: torch.dtype, workspace: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + quantization_params: Optional[Quantizer] = None, gelu: bool = False, - gelu_input: Optional[torch.Tensor] = None, - grad: bool = False, + gelu_in: torch.Tensor = None, accumulate: bool = False, layout: str = "TN", out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - use_bias: bool = False, - ub_algo: tex.CommOverlapAlgo = None, + use_split_accumulator: bool = False, + grad: bool = False, ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, - extra_output_tensor: torch.Tensor = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Non FP8 GEMM.""" + ub_type: tex.CommOverlapType = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """GEMM supporting fp8 inputs.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." 
transa = layout[0] == "T" transb = layout[1] == "T" - empty_tensor = _empty_tensor() - fp8_index = -1 # dummy index - - if out is None: - out = torch.empty( - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - dtype=dtype, - device="cuda", + # assert quantization_params is None, "FP8 output not supported yet" + + if ub_type is not None: + assert ub is not None, ( + f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" + + "a valid `ub` communicator object." ) - else: - if not out.is_contiguous(): - raise ValueError("Output tensor is not contiguous.") - if gelu and not grad: - gelu_input = torch.empty_like(out, dtype=dtype) - elif not gelu: - gelu_input = empty_tensor + if ub is not None: + assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument." + if ub_type == tex.CommOverlapType.RS: + if not (bulk_overlap and not ub.is_fp8_ubuf()): + assert extra_output is not None, "GEMM+RS overlap requires extra output tensor." 
- if grad and use_bias: - grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda") - else: - grad_bias = empty_tensor - - bias = bias if use_bias else empty_tensor + if out is not None: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") - assert ( - A.dtype == dtype and B.dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out.dtype] - if use_bias: - bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] - else: - bias_dtype = output_dtype + # Use bfloat16 as default bias_dtype + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] args = ( A, - empty_tensor, - fp8_index, - input_dtype, - transa, + transa, # transa B, - empty_tensor, - fp8_index, - input_dtype, - transb, + transb, # transb out, - empty_tensor, # out_scale - output_dtype, - empty_tensor, # out_amax - grad_bias if grad else bias, + quantization_params, + TE_DType[out_dtype] if out_dtype is not None else None, + bias, bias_dtype, - gelu_input, - grad, + gelu, + gelu_in, + grad, # grad workspace, workspace.shape[0], accumulate, - False, # use_split_accumulator + use_split_accumulator, ) - fn = torch.ops.tex_ts.te_gemm_ts - if ub_algo is not None: - assert ub is not None, "ub object is None!" 
- if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - args = tuple(args + (tex.CommOverlapType.AG, empty_tensor)) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - args = tuple(args + (tex.CommOverlapType.RS, empty_tensor)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - False, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) - - return out, grad_bias, gelu_input - - -def grouped_gemm( + kwargs = { + "comm_overlap": ub, + "comm_type": ub_type, + "extra_output": extra_output, + "bulk_overlap": bulk_overlap, + } + + original_scale_inverses = swizzle_inputs(A, B, layout) + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + reset_swizzled_inputs(A, B, original_scale_inverses) + + return out, bias_grad, gelu_input, extra_output + + +def general_grouped_gemm( A: List[torch.Tensor], B: List[torch.Tensor], out: List[torch.Tensor], - dtype: torch.dtype, + out_dtype: torch.dtype, workspaces: List[torch.Tensor], + layout: str = "TN", + m_splits: Optional[List[int]] = None, gelu: bool = False, - gelu_input: Optional[List[torch.Tensor]] = None, - grad: bool = False, + grad=False, accumulate: bool = False, - layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, + use_split_accumulator: bool = False, + D_dtype: Optional[tex.DType] 
= None, + single_output=False, ) -> Tuple[List[torch.Tensor], ...]: - """Non FP8 Grouped GEMM.""" + """ + TN layout Grouped GEMM with fp8 inputs. + """ + num_gemms = len(A) - assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" - num_gemms = len(A) + + # assert [a.is_contiguous() for a in A] + # assert [b.is_contiguous() for b in B] + + if isinstance(A[0], Float8TensorBase): + for a, b in zip(A, B): + assert_dim_for_fp8_exec(a._data) + assert_dim_for_fp8_exec(b._data) + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms - if gelu and not grad: - gelu_input = [ - torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out - ] - elif not gelu: - gelu_input = empty_tensors + # Use bfloat16 as default bias_dtype + gelu_input = empty_tensors + out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype + sm_count = get_sm_count() if grad and use_bias: grad_bias = [ torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms) ] else: grad_bias = empty_tensors - bias = bias if use_bias else empty_tensors - - assert ( - A[0].dtype == dtype and B[0].dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out[0].dtype] if use_bias: bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype] else: - bias_dtype = output_dtype + bias_dtype = TE_DType[torch.bfloat16] - torch.ops.tex_ts.te_grouped_gemm_ts( + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] # this should differ with respect to single output + + bias = tex.te_general_grouped_gemm( A, - empty_tensor, - 0, # A_offset - input_dtype, transa, B, - empty_tensor, - 0, # B_offset - input_dtype, transb, out, - 0, # out_offset - empty_tensor, # out_scale - output_dtype, - 
empty_tensor, # out_amax + out_dtype, + m_splits, grad_bias if grad else bias, bias_dtype, - gelu_input, # gelu_input - grad, + single_output, + gelu_input, # this is pre_gelu_out + grad, # grad workspaces, workspaces[0].shape[0], accumulate, - False, # use_split_accumulator + use_split_accumulator, + sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))), ) - return out, grad_bias, gelu_input - - -def fp8_grouped_gemm( - A: List[torch.Tensor], - A_scale_inv: List[torch.Tensor], - A_fp8_tensor_offset: int, - A_dtype: tex.DType, - B: List[torch.Tensor], - B_scale_inv: torch.Tensor, - B_fp8_tensor_offset: int, - B_dtype: tex.DType, - out: List[torch.Tensor], - out_dtype: torch.dtype, - workspaces: List[torch.Tensor], - m_splits: Optional[List[int]] = None, - out_offset: Optional[int] = None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - gelu: bool = False, - accumulate: bool = False, - bias: Optional[List[torch.Tensor]] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> Tuple[List[torch.Tensor], ...]: - """ - TN layout Grouped GEMM with fp8 inputs. - Input requirements: - 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. - This is used for the calculation of output (fwd) and dgrad (bwd). - 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the - calculation of wgrad. 
- """ - num_gemms = len(A) - if num_gemms > 1 and len(A_scale_inv) == num_gemms: - assert len(out) == 1 and m_splits is not None - elif num_gemms > 1 and len(A_scale_inv) == 1: - assert len(out) == num_gemms - elif num_gemms == 1: - assert len(A_scale_inv) == 1 and len(out) == 1 - else: - raise ValueError("Invalid input combinations of A_scale_inv and out.") - - empty_tensor = _empty_tensor() - empty_tensors = [empty_tensor] * num_gemms - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_offset is not None - for a, b in zip(A, B): - assert_dim_for_fp8_exec(a) - assert_dim_for_fp8_exec(b) - assert A[0].dtype == torch.uint8 - assert B[0].dtype == torch.uint8 - - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - bias_dtype = TE_DType[bias_dtype] - gelu_input = empty_tensors - out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - - if len(A_scale_inv) == 1: - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv[0], - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - else: - if gelu: - gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] - - torch.ops.tex_ts.te_grouped_gemm_single_output_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - 
B_dtype, - False, # transb - m_splits, - out[0], - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - - return out, gelu_input + return out, bias, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py deleted file mode 100644 index f997a8a536..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for normalization extensions""" -from typing import Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - - -__all__ = [ - "layernorm_fwd_fp8", - "layernorm_fwd_fp8_inf", - "layernorm_fwd_inf", - "rmsnorm_fwd_fp8", - "rmsnorm_fwd_fp8_inf", - "rmsnorm_fwd_inf", -] - - -def layernorm_fwd_fp8( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - ln_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """LayerNorm with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, 
- ) - - # Launch kernel - if ln_out is not None: - return tex.layernorm_fwd_fp8_noalloc( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - ln_out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - return tex.layernorm_fwd_fp8( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - - -def layernorm_fwd_fp8_inf( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """LayerNorm with FP8 output. - - This version of layernorm_fwd_fp8 is specialized for inference, and returns - only the normalized output. 
- """ - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - sm_margin, - zero_centered_gamma, - ) - return ret - - -def layernorm_fwd_inf( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - sm_margin: int, - zero_centered_gamma: bool, -) -> torch.Tensor: - """LayerNorm with FP8 output""" - return torch.ops.tex_ts.layernorm_fwd_inf_ts( - inp, - weight, - bias, - eps, - sm_margin, - zero_centered_gamma, - ) - - -def rmsnorm_fwd_fp8( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - rmsnorm_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """RMSNorm with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - if rmsnorm_out is not None: - return tex.rmsnorm_fwd_fp8_noalloc( - inp, - weight, - eps, - fp8_scales["scale"], - rmsnorm_out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - return tex.rmsnorm_fwd_fp8( - inp, - weight, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - 
zero_centered_gamma, - **fp8_scales_offsets, - ) - - -def rmsnorm_fwd_fp8_inf( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """RMSNorm with FP8 output. - - This version of rmsnorm_fwd_fp8 is specialized for inference, and returns - only the normalized output. - """ - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( - inp, - weight, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - sm_margin, - zero_centered_gamma, - ) - return ret - - -def rmsnorm_fwd_inf( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - sm_margin: int, - zero_centered_gamma: bool, -) -> torch.Tensor: - """RMSNorm with FP8 output""" - return torch.ops.tex_ts.rmsnorm_fwd_inf_ts( - inp, - weight, - eps, - sm_margin, - zero_centered_gamma, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py deleted file mode 100644 index cf704d06ee..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/padding.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Python interface for transpose extensions""" -from typing import List, Tuple, Union -import torch -import transformer_engine_torch as tex - - -__all__ = [ - "multi_padding_fused", -] - - -def multi_padding_fused( - inp: torch.Tensor, - row_list: List[int], - padded_row_list: List[int], - out: torch.Tensor, -) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], None]: - """Padding""" - - tex.fused_multi_row_padding( - inp, - out, - row_list, - padded_row_list, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py deleted file mode 100644 index 77bf0019af..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for transpose extensions""" -from typing import List, Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex -from ..constants import TE_DType -from ._common import canonicalize_fp8_scales, empty_tensor - - -__all__ = [ - "fp8_cast_transpose_fused", - "fp8_cast_transpose_bgrad_fused", - "fp8_cast_transpose_bgrad_dgelu_fused", - "fp8_dswiglu_cast_transpose_fused", - "fp8_multi_cast_transpose_fused", - "fp8_transpose_bgrad_fused", -] - - -def fp8_cast_transpose_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - cast_out: Optional[torch.Tensor] = None, - transpose_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - noop_flag: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Cast + Transpose with FP8 output""" - - # Allocate outputs if needed - if transpose_out is None: - transpose_out = torch.empty(inp.shape[1], 
inp.shape[0], device="cuda", dtype=torch.uint8) - if cast_out is None: - cast_out = torch.empty_like(inp, dtype=torch.uint8) - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Construct no-op flag if needed - if noop_flag is None: - noop_flag = empty_tensor() - - # Launch kernel if needed - if inp.nelement() > 0: - tex.fused_cast_transpose_noop( - inp, - noop_flag, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - cast_out, - transpose_out, - otype, - **fp8_scales_offsets, - ) - - return cast_out, transpose_out - - -def fp8_cast_transpose_bgrad_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Cast + Transpose + BGRAD with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_cast_transpose_bgrad( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_transpose_bgrad_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - grad_bias_type: torch.dtype, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transpose + BGRAD with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, 
fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_fp8_transpose_bgrad( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - TE_DType[grad_bias_type], - **fp8_scales_offsets, - ) - - -def fp8_cast_transpose_bgrad_dgelu_fused( - grad_output: torch.Tensor, - gelu_input: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Cast + Transpose + BGRAD + DGELU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_cast_transpose_bgrad_dgelu( - grad_output, - gelu_input, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_dswiglu_cast_transpose_fused( - grad_output: torch.Tensor, - inp: torch.Tensor, - *, - grad_input: torch.Tensor, - grad_input_transpose: torch.Tensor, - otype: tex.DType, - fp8_meta: Optional[tex.FP8TensorMeta] = None, - fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> None: - """Fused SwiGLU backward + FP8 cast + FP8 transpose""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta, - fp8_meta_index=fp8_meta_index, - ) - - # Launch kernel - return tex.fused_dswiglu_cast_transpose( - 
grad_output, - inp, - grad_input, - grad_input_transpose, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_multi_cast_transpose_fused( - input_list: List[torch.Tensor], - fp8_meta_tensor: tex.FP8TensorMeta, - scale_indices: List[int], - amax_indices: List[int], - scale_inv_indices: List[int], - otype: tex.DType, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - """Cast + Transpose with FP8 output""" - - return tex.fused_multi_cast_transpose_alloc( - input_list, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, - scale_indices, - amax_indices, - scale_inv_indices, - otype, - ) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 2c8736ee09..33de562a89 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -9,13 +9,27 @@ import torch -from .float8_tensor import Float8Tensor +from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False +def set_offloading_param(tensor, param_name, value): + """Set the type of the offloading needed for a tensor.""" + assert param_name in ["weight_offloading", "activation_offloading"] + if tensor is None: + return + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + setattr(tensor, param_name, value) + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + setattr(tensor, param_name, value) + + def is_cpu_offload_enabled() -> bool: """Check if CPU offloading is currently enabled.""" return CPUOffloadEnabled @@ -258,6 +272,7 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs): else: # will be offloaded together after group commit self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag def tensor_pop(self, tensor_tag, 
**kwargs): @@ -366,6 +381,7 @@ def bulk_offload_group(self, group_to_offload): if self.tensor_need_offloading_checker(tensor_on_device): state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) self.tensor_tag_to_state[tensor_tag] = state + tensor_on_device.data = torch.Tensor() # Force to release memory def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index eb97dc36eb..5775fe381d 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -6,7 +6,33 @@ #include "common.h" +#include "c10/util/ArrayRef.h" +#include "pybind.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine::pytorch { + +std::vector getTensorShape(at::Tensor t) { + std::vector shape; + for (auto s : t.sizes()) { + shape.push_back(s); + } + return shape; +} + +std::unique_ptr convert_quantizer(py::handle quantizer) { + init_extension(); + if (quantizer.is_none()) { + return std::make_unique(quantizer); + } + for (auto [_check_type, check_quantizer_type, _create_tensor, create_quantizer] : + detail::custom_types_converters) { + if (check_quantizer_type(quantizer.ptr())) { + return create_quantizer(quantizer); + } + } + + NVTE_ERROR("Unexpected type for quantizer"); +} transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe) { @@ -17,6 +43,34 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, return transformer_engine::DType::kFloat8E5M2; } +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { + NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + for (auto [check_type, check_quantizer_type, create_tensor, _] : + detail::custom_types_converters) { + if (check_type(tensor.ptr())) { + 
NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()), + "Unexpected quantization params type."); + auto x = create_tensor(tensor, my_quantizer.get()); + return x; + } + } + + // Regular pyTorch tensor + at::Tensor torch_tensor = tensor.cast(); + + // #TODO (pgadzinski) - needed in attention for non-contiguous tensors. + //if (!torch_tensor.is_contiguous()) { + // torch_tensor = torch_tensor.contiguous(); + //} + auto ret = TensorWrapper(my_quantizer->get_scaling_mode()); + ret.set_rowwise_data(torch_tensor.data_ptr(), + GetTransformerEngineDType(torch_tensor.scalar_type()), + getTensorShape(torch_tensor)); + my_quantizer->set_quantization_params(&ret); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) { return transformer_engine::TensorWrapper(data_ptr, shape, type); @@ -30,48 +84,95 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); std::vector shape; - for (auto s : tensor.sizes()) { shape.push_back(s); } return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, - void* scale_inv_ptr) { - return transformer_engine::TensorWrapper( - data_ptr, shape, type, reinterpret_cast(amax_ptr), - reinterpret_cast(scale_ptr), reinterpret_cast(scale_inv_ptr)); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + 
ret.set_rowwise_data(data_ptr, type, shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape, + const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, + columnwise_scale_inv_shape); + return ret; } transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, - at::Tensor scale_inv) { + at::Tensor scale_inv, + NVTEScalingMode scaling_mode) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + auto tensor_shape = getTensorShape(tensor); + auto scale_inv_shape = getTensorShape(scale_inv); + NVTE_CHECK(amax.scalar_type() == at::kFloat); NVTE_CHECK(scale.scalar_type() == at::kFloat); NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); - return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + return makeTransformerEngineTensor(tensor.data_ptr(), tensor_shape, dtype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape, + scaling_mode); } -size_t product(const std::vector& shape) { - size_t ret = 1; +template +T product(const std::vector& shape) { + T ret = 1; for (auto s : shape) { ret *= s; } return ret; } +template size_t product(const std::vector& shape); +template int64_t product(const std::vector& shape); + +size_t product(const NVTEShape& shape, size_t begin, size_t end) { + NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, + " in a shape with ", shape.ndim, " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape.data[i]; + } + return ret; +} + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape) { + std::vector shape; + for (size_t i = 0; i < nvte_shape.ndim; i++) { + shape.push_back(nvte_shape.data[i]); + } + return shape; +} + at::Tensor allocateSpace(const std::vector& 
shape, const transformer_engine::DType type, bool init_to_zeros) { std::vector shape_int64(shape.begin(), shape.end()); @@ -121,3 +222,14 @@ void* getDataPtr(at::Tensor tensor, int offset) { } return dptr; } + +std::vector convertShape(const NVTEShape& shape) { + return std::vector(shape.data, shape.data + shape.ndim); +} + +int roundup(const int value, const int multiple) { + assert(multiple > 0); + return ((value + multiple - 1) / multiple) * multiple; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94e1f7569a..40245cf2d9 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -33,23 +33,22 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include -#include -#include #include #include +#include "c10/util/ArrayRef.h" #include "common/util/logging.h" -namespace transformer_engine { +namespace transformer_engine::pytorch { // Each tensor here is shape (N, ) holding all scaling // data for a single FP8 block, e.g. 
LayerNormLinear @@ -85,7 +84,76 @@ enum FP8BwdTensors { GRAD_INPUT3 = 5 }; -} // namespace transformer_engine +class Quantizer { + public: + virtual NVTEScalingMode get_scaling_mode() const = 0; + + virtual void set_quantization_params(TensorWrapper* tensor) const = 0; + + virtual std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const = 0; + + virtual ~Quantizer() = default; + + bool rowwise_usage = true; + bool columnwise_usage = true; + bool internal = false; + py::handle quantizer; + + protected: + explicit Quantizer(const py::handle& quantizer); +}; + +class NoneQuantizer : public Quantizer { + public: + explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {} + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override {} + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class Float8Quantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + + explicit Float8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class MXFP8Quantizer : public Quantizer { + public: + DType dtype; + + explicit MXFP8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +std::unique_ptr 
convert_quantizer(py::handle quantizer); + +std::vector getTensorShape(at::Tensor t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -103,9 +171,11 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { case transformer_engine::DType::kBFloat16: return at::kBFloat16; case transformer_engine::DType::kByte: + return at::kByte; case transformer_engine::DType::kFloat8E4M3: + return at::kFloat8_e4m3fn; case transformer_engine::DType::kFloat8E5M2: - return at::kByte; + return at::kFloat8_e5m2; default: NVTE_ERROR("Invalid type"); } @@ -113,6 +183,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { switch (t) { + case at::kFloat8_e4m3fn: + return transformer_engine::DType::kFloat8E4M3; + case at::kFloat8_e5m2: + return transformer_engine::DType::kFloat8E5M2; case at::kHalf: return transformer_engine::DType::kFloat16; case at::kFloat: @@ -128,6 +202,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { case torch::kInt64: return transformer_engine::DType::kInt64; default: + std::cout << "Type: " << static_cast(t) << std::endl; NVTE_ERROR("Invalid type"); } } @@ -140,11 +215,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const std::vector& shape, const transformer_engine::DType type); -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, - void* scale_inv_ptr); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper 
makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape = {1}, + const std::vector& columnwise_scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, @@ -152,11 +234,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, - const at::Tensor scale, - at::Tensor scale_inv); +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +template +T product(const std::vector& shape); -size_t product(const std::vector& shape); +size_t product(const NVTEShape& shape, size_t begin, size_t end); + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros); @@ -170,4 +259,54 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); +std::vector convertShape(const NVTEShape& shape); + +int roundup(const int value, const int multiple); + +} // namespace transformer_engine::pytorch + +namespace std { +template +string to_string(const vector& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; 
+ } else { + ret += "]"; + } + return ret; +} + +// Torch shape -> string +template +string to_string(const c10::ArrayRef& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +inline string to_string(const NVTEShape& s) { + string ret = "["; + for (size_t i = 0; i < s.ndim; ++i) { + ret += to_string(s.data[i]) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} +} // namespace std + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 58527ef6d5..e871228b80 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,7 +10,6 @@ #include #include "common.h" -#include "common/common.h" /*************************************************************************************************** * Permutation @@ -45,93 +44,27 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); -std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const 
c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); - -std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, - const c10::optional descale_S, const c10::optional descale_O, - const c10::optional descale_dO, const c10::optional descale_dP, - const c10::optional scale_S, const c10::optional scale_dP, - const c10::optional scale_dQKV, c10::optional amax_dP, - c10::optional amax_dQKV); - -std::vector fused_attn_fwd_kvpacked( +std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, 
py::handle s_quantizer, + py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread); -std::vector fused_attn_bwd_kvpacked( +std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV); - -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, 
const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); - -std::vector fused_attn_bwd( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV); + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); @@ -140,237 +73,146 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); * GEMM **************************************************************************************************/ -void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, 
size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); +using MaybeTensor = std::optional; void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter); -void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, - bool transb, std::vector D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - std::vector bias, transformer_engine::DType bias_type, - std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); - -void te_grouped_gemm_single_output( - std::vector A, std::vector A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, - std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, - transformer_engine::DType 
D_type, at::Tensor D_amax, std::vector bias, - transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); /*************************************************************************************************** * Transpose **************************************************************************************************/ -void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - at::Tensor input_cast, at::Tensor input_transpose, - transformer_engine::DType otype); - -void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, at::Tensor input_cast, - at::Tensor input_transpose, transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset 
= 0, int amax_offset = 0, - int scale_inv_offset = 0); - -void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, - at::Tensor grad_input_transpose, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, int scale_offset = 0, - int amax_offset = 0, int scale_inv_offset = 0); - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_output_list, - std::vector scale_inv_output_list, - transformer_engine::DType otype); - -std::tuple, std::vector> fused_multi_cast_transpose_alloc( - std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - std::vector scale_indices, std::vector amax_indices, - std::vector scale_inv_indices, transformer_engine::DType otype); - -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); - -void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, - transformer_engine::DType otype); +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype); + +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output = std::nullopt); + +namespace transformer_engine::pytorch { /*************************************************************************************************** * Activations **************************************************************************************************/ -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object gelu(const at::Tensor &input, py::handle quantizer); + +py::object relu(const at::Tensor &input, py::handle quantizer); -at::Tensor 
relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object geglu(const at::Tensor &input, py::handle quantizer); -at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object qgeglu(const at::Tensor &input, py::handle quantizer); -at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object reglu(const at::Tensor &input, py::handle quantizer); -at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object swiglu(const at::Tensor &input, py::handle quantizer); -at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object qgelu(const at::Tensor &input, py::handle quantizer); -at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object srelu(const at::Tensor &input, py::handle quantizer); -at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dreglu(const at::Tensor 
&grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +} // namespace transformer_engine::pytorch /*************************************************************************************************** * LayerNorm **************************************************************************************************/ -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); -std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); - -std::vector layernorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, - at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, const int scale_inv_offset = 0); - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, 
const int sm_margin, - const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + transformer_engine::DType out_dtype, const int sm_margin, const bool zero_centered_gamma); -std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, at::Tensor ln_out, float eps, - const int sm_margin, const bool zero_centered_gamma); - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma); - /*************************************************************************************************** * RMSNorm **************************************************************************************************/ -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); - -std::vector rmsnorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, float eps, at::Tensor scale, - at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, 
const int scale_inv_offset = 0); - -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma); +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma); + +/*************************************************************************************************** + * Cast + **************************************************************************************************/ -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma); +namespace transformer_engine::pytorch { -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma); +py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, + std::optional noop); + +py::object dequantize(const py::handle &input, transformer_engine::DType otype); + +std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer); + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, + std::optional comm_type = std::nullopt, + MaybeTensor 
extra_output = std::nullopt, bool bulk_overlap = false); /*************************************************************************************************** - * Cast + * Cast fusions **************************************************************************************************/ -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype, - const int scale_inv_offset = 0); +} // namespace transformer_engine::pytorch /*************************************************************************************************** * Softmax @@ -405,7 +247,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, - std::vector scale_invs, const std::string &amax_compute_algo, 
transformer_engine::DType fp8_dtype, float margin); @@ -518,6 +359,16 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +/*************************************************************************************************** + * swizzle + **************************************************************************************************/ + +void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans); + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv); + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv); + /*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -551,151 +402,44 @@ class CommOverlapHelper : torch::CustomClassHolder { }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, int comm_type); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - 
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + + ~CommOverlap() {} + + void set_buffer_params(py::handle quantizer); + + void 
copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, - bool use_ce = true, bool aggregate = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, bool chunk); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. 
- */ - void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy); - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor B_copy); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output); - - 
/* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + bool aggregate = false); + + ~CommOverlapP2P() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + }; // CommOverlapP2P #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 48832e6994..7ce33ee77b 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -5,272 +5,114 @@ ************************************************************************/ #include "extensions.h" +#include "pybind.h" -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; +namespace transformer_engine::pytorch { - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { + init_extension(); + 
auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); - auto output = allocateTorchTensor(M, N, otype); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + input_shape[input_shape.size() - 1] /= shape_divisor; + auto fake_tensor_type = input.scalar_type(); - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - return output; + return out; } -at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; +template +py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); + auto grad_tensor = grad.contiguous(); - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = input.scalar_type(); - auto output = allocateTorchTensor(M, N, otype); + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, 
GetTransformerEngineDType(fake_tensor_type)); - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; + return out; } -at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object gelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto 
output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return 
dactivation_helper(grad, input, quantizer); } -at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object geglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = 
allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, 
otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_qgelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object reglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dqgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_srelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object swiglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } 
-at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} - auto output = allocateTorchTensor(M, N, otype); +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} - nvte_dsrelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} - return output; +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index d9977f01b9..c323e7b6c1 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -8,7 +8,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const bool transpose_output_memory) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(input.dim() == 
4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(input.size(0) <= freqs.size(0), @@ -66,7 +66,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const bool transpose_output_memory) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(output_grads.size(0) <= freqs.size(0), @@ -122,7 +122,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -174,7 +174,7 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9c9ffdb1a7..f2d1ecf3b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include "common/common.h" #include "common/fused_attn/thd_utils.h" #include "extensions.h" @@ -40,22 +41,27 @@ __global__ void __launch_bounds__(block_size) } // fast zero-fills of tensors -void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { - auto max_tokens = self.size(0); - auto self_2d = self.view({max_tokens, -1}); - auto fcd_size = self_2d.size(1); - TORCH_CHECK(self.is_contiguous(), "input not contiguous"); +void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) { + std::vector shape = transformer_engine::pytorch::convertShape(self.shape()); + + auto max_tokens = shape[0]; + size_t fcd_size = 1; + for (size_t i = 1; i < shape.size(); i++) { + fcd_size *= shape[i]; + } TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); dim3 dim_grid(num_blk_x, num_blk_y); dim3 dim_block(block_size); + // convert the TE DType to an ATen scalar_type for the dispatch macro below + at::ScalarType scalar_type = transformer_engine::pytorch::GetATenDType(self.dtype()); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_fill", [&]() { + at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "mha_fill", [&]() { mha_fill_kernel<<>>( - self_2d.data_ptr(), static_cast(start_index.data_ptr()), - max_tokens); + static_cast(self.get_rowwise_data().data_ptr), + static_cast(start_index.data_ptr()), max_tokens); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -80,735 +86,48 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe return philox_args; } -// fused attention FWD with packed QKV -std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, 
bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { - using namespace transformer_engine; - - auto qkv_sizes = QKV.sizes().vec(); - std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; - std::vector q_shape; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - int loc_3 = 0; - switch (layout_group) { - case NVTE_3HD: - loc_3 = qkv_sizes.size() - 3; - break; - case NVTE_H3D: - loc_3 = qkv_sizes.size() - 2; - break; - default: - NVTE_ERROR("Invalid QKV layout group."); - } - for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { - if (it - qkv_shape.begin() != loc_3) { - q_shape.push_back(*it); - } - } - std::vector o_shape{q_shape.begin(), q_shape.end()}; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty(o_shape, options); - - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens, te_cu_seqlens_padded; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, 
cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); - } - // BF16 or FP16 - te_QKV = - makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); - } - auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); - std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if (cu_seqlens_padded.has_value()) { - auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec(); - std::vector cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(), - cu_seqlens_padded_sizes.end()}; - te_cu_seqlens_padded = - makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // extract random number generator seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - 
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; - } - } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed QKV -std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, - const c10::optional descale_S, const c10::optional descale_O, - const c10::optional descale_dO, const c10::optional descale_dP, - const c10::optional scale_S, const c10::optional scale_dP, - const c10::optional scale_dQKV, c10::optional amax_dP, - c10::optional amax_dQKV) { - using namespace transformer_engine; - - auto qkv_sizes = QKV.sizes().vec(); - std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; - std::vector q_shape; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - int loc_3 = 0; - 
switch (layout_group) { - case NVTE_3HD: - loc_3 = qkv_sizes.size() - 3; - break; - case NVTE_H3D: - loc_3 = qkv_sizes.size() - 2; - break; - default: - NVTE_ERROR("Invalid QKV layout group."); - } - for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { - if (it - qkv_shape.begin() != loc_3) { - q_shape.push_back(*it); - } - } - auto h = q_shape[q_shape.size() - 2]; - - // create output tensor dQKV - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - at::Tensor dQKV = torch::empty_like(QKV, options); - - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQKV.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQKV.fill_(0); - } - // BF16 or FP16 - te_QKV = - makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, nullptr, nullptr, - nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } - - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h), static_cast(max_seqlen), - static_cast(max_seqlen)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); - std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - TensorWrapper te_cu_seqlens = makeTransformerEngineTensor( - cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); - - TensorWrapper te_cu_seqlens_padded; - if (cu_seqlens_padded.has_value()) { - auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec(); - std::vector cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(), - cu_seqlens_padded_sizes.end()}; - te_cu_seqlens_padded = - makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // 
create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQKV, dBias}; -} - -// fused attention FWD with packed KV -std::vector fused_attn_fwd_kvpacked( +// fused attention FWD with separate Q, K and V tensors +std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, - const 
c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; + TensorWrapper te_Q, te_K, te_V, te_O, te_S; - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto kv_sizes = KV.sizes().vec(); - std::vector kv_shape{kv_sizes.begin(), kv_sizes.end()}; - std::vector o_shape{q_shape.begin(), q_shape.end()}; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty(o_shape, options); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are 
required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); - } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_KV = - makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); - } - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - } - - // extract rng seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, 
options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], - workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; - } - } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], - workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed KV -std::vector fused_attn_bwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV) { - using namespace transformer_engine; - - auto 
q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto kv_sizes = KV.sizes().vec(); - std::vector kv_shape{kv_sizes.begin(), kv_sizes.end()}; - std::vector k_shape; - for (auto i : kv_shape) { - if (i != 2) { - k_shape.push_back(i); - } - } - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - - // create output tensors dQ and dKV - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - at::Tensor dQ = torch::empty_like(Q, options); - at::Tensor dKV = torch::empty_like(KV, options); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dKV.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQ.fill_(0); - dKV.fill_(0); - } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_KV = - makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); 
- te_dKV = - makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + auto none = py::none(); + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - } + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = 
makeTransformerEngineTensor(V, none); - // convert auxiliary tensors from forward to NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } + // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, 
p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQ, dKV, dBias}; -} - -// fused attention FWD with separate Q, K and V tensors -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t 
rng_elts_per_thread) { - using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto k_sizes = K.sizes().vec(); - std::vector k_shape{k_sizes.begin(), k_sizes.end()}; - auto v_sizes = V.sizes().vec(); - std::vector v_shape{v_sizes.begin(), v_sizes.end()}; - - // create output tensor O + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto o_shape = std::vector{q_sizes.begin(), q_sizes.end()}; - o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1]; - std::vector o_shape_tmp{o_shape.begin(), o_shape.end()}; - auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options); + // create output tensor O + + auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; + o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + py::object o_python, s_python; + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; + TensorWrapper te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { @@ -817,55 +136,30 @@ std::vector fused_attn_fwd( auto d = q_shape[q_shape.size() - 1]; if (set_zero && ((h * d) % block_size == 0) && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || 
(!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + te_O.zero_(at::cuda::getCurrentCUDAStream()); } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); + te_O.zero_(at::cuda::getCurrentCUDAStream()); } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = + makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); + te_cu_seqlens_kv = + makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); @@ -875,11 +169,9 @@ std::vector fused_attn_fwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + 
cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); } // extract rng seed and offset @@ -913,8 +205,8 @@ std::vector fused_attn_fwd( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); + std::vector output_tensors; + output_tensors.push_back(o_python); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors @@ -936,7 +228,7 @@ std::vector fused_attn_fwd( } else { output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); } - output_tensors.push_back(output_tensor); + output_tensors.push_back(py::cast(output_tensor)); tensor->data.dptr = output_tensor.data_ptr(); } @@ -957,45 +249,55 @@ std::vector fused_attn_fwd( } // fused attention BWD with separate Q, K and V -std::vector fused_attn_bwd( +std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const 
c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV) { + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer) { using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto k_sizes = K.sizes().vec(); - std::vector k_shape{k_sizes.begin(), k_sizes.end()}; - auto v_sizes = V.sizes().vec(); - std::vector v_shape{v_sizes.begin(), v_sizes.end()}; + using namespace transformer_engine::pytorch; + auto none = py::none(); + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); + te_O = makeTransformerEngineTensor(O, none); + te_dO = makeTransformerEngineTensor(dO, none); + // qkv type from the te_Q + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + + py::object s_python, dp_python; + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; auto d_v = v_shape[v_shape.size() - 1]; auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_sizes.begin(), q_sizes.end()}; + 
std::vector o_shape{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = d_v; - at::Tensor dQ; - at::Tensor dK; - at::Tensor dV; - at::Tensor dQKV, dKV; + at::Tensor dQ, dK, dV, dQKV, dKV; + py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1012,7 +314,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1026,8 +328,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1040,8 +343,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = 
torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1052,82 +356,41 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - dQ = torch::empty_like(Q, options); - dK = torch::empty_like(K, options); - dV = torch::empty_like(V, options); + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + dK = torch::empty(tmp_shape, options); + tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + dV = torch::empty(tmp_shape, options); break; default: NVTE_ERROR("QKV layout not supported!"); } + std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, 
torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { dQ.fill_(0); dK.fill_(0); dV.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dK = - makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dV 
= - makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dK = - makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dV = - makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } @@ -1152,11 +415,9 @@ std::vector fused_attn_bwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -1219,7 +480,7 @@ std::vector fused_attn_bwd( // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {dQ, dK, dV, dBias}; + return {py_dQ, py_dK, py_dV, py::cast(dBias)}; } namespace flash_attention { diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp new file mode 100644 index 0000000000..a1fe8bd2b5 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "common.h" +#include "pybind.h" +#include "transformer_engine/cast.h" + +namespace transformer_engine::pytorch { + +std::vector bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { + auto quantizer = convert_quantizer(py_quantizer); + + auto input_tensor = makeTransformerEngineTensor(input); + + auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype()); + + std::vector output_shape; + for (auto s : input.sizes()) { + output_shape.emplace_back(static_cast(s)); + } + auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype()); + + // Return immediately if tensors are empty + if (product(output_shape) == 0) { + return {py::cast(dbias.zero_()), out}; + } + + auto dbias_tensor = makeTransformerEngineTensor(dbias); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + void* workspace_data_ptr = nullptr; + if (workspace.shape().ndim > 0) { + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace_data_ptr = workspace_data.data_ptr(); + } + workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); + + // Launch kernel + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + return {py::cast(dbias), out}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 771fa4920a..66dafdaafb 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -4,69 +4,126 @@ * See LICENSE for license information. 
************************************************************************/ +#include "transformer_engine/cast.h" + +#include "common.h" #include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, + std::optional noop) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = tensor.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = tensor.scalar_type(); + if (!detail::IsFloatingPointType(fake_tensor_type)) { + fake_tensor_type = at::kFloat; + } + + TensorWrapper te_output; + py::object out; + if (output.is_none()) { + DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); + std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); + } else { + out = output; + te_output = makeTransformerEngineTensor(output, quantizer); + } + + TensorWrapper te_noop; + if (noop.has_value()) { + te_noop = makeTransformerEngineTensor(*noop); + } else { + te_noop = TensorWrapper(); + } + + if (te_output.numel() == 0) return out; + nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), + at::cuda::getCurrentCUDAStream()); + + return out; +} + +py::object dequantize(const py::handle& input, transformer_engine::DType otype) { + init_extension(); -at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset, const int amax_offset, const int scale_inv_offset) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; + const auto none = 
py::none(); - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + const auto& input_tensor = makeTransformerEngineTensor(input, none); - if (input.numel() == 0) return output; + NoneQuantizer q(none); - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + const auto& shape = convertShape(input_tensor.shape()); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, - scale_dptr, scale_inv_dptr); + auto [out_tensor, out] = q.create_tensor(shape, otype); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); - return output; + return out; } -void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { - using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); +template +std::vector dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + + auto grad_tensor = makeTransformerEngineTensor(grad_output); + + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); + auto act_input_tensor = makeTransformerEngineTensor(act_input); + + const auto& shape = convertShape(grad_tensor.shape()); + auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* 
amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto dbias_tensor = makeTransformerEngineTensor(grad_bias); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + // Launch kernel + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); - return; + return {py::cast(grad_bias), dact}; } -at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype, - const int scale_inv_offset) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; +std::vector dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); +std::vector dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - getDataPtr(scale_inv, scale_inv_offset)); - auto output_cu = 
makeTransformerEngineTensor(output); +std::vector dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +std::vector dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - return output; +std::vector dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 6b54f2de69..30126651ce 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "../extensions.h" +#include "transformer_engine/transformer_engine.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 @@ -14,50 +15,6 @@ using namespace std::placeholders; namespace te = transformer_engine; -#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \ - B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \ - bias_type, pre_gelu_out, workspace) \ - A = A.contiguous(); \ - void *A_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(A_type)) { \ - assert(A_scale_inv.numel()); \ - A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ - } \ - auto A_ = makeTransformerEngineTensor( \ - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ - nullptr, nullptr, A_scale_inv_ptr); \ - B = B.contiguous(); \ - void *B_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(B_type)) { \ - 
assert(B_scale_inv.numel()); \ - B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ - } \ - auto B_ = makeTransformerEngineTensor( \ - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ - nullptr, nullptr, B_scale_inv_ptr); \ - void *D_amax_ptr = nullptr; \ - void *D_scale_ptr = nullptr; \ - if (te::is_fp8_dtype(D_type)) { \ - assert(D_amax.numel()); \ - D_amax_ptr = D_amax.data_ptr(); \ - assert(D_scale.numel()); \ - D_scale_ptr = D_scale.data_ptr(); \ - } \ - auto D_ = makeTransformerEngineTensor( \ - D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, \ - D_amax_ptr, D_scale_ptr, nullptr); \ - auto bias_ = makeTransformerEngineTensor( \ - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ - const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ - ? std::vector{static_cast(pre_gelu_out.size(0))} \ - : std::vector{static_cast(pre_gelu_out.size(0)), \ - static_cast(pre_gelu_out.size(1))}; \ - auto pre_gelu_out_ = makeTransformerEngineTensor( \ - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ - auto workspace_ = makeTransformerEngineTensor( \ - workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ - te::DType::kByte); - /*************************************************************************************************** * CommOverlapHelper **************************************************************************************************/ @@ -185,145 +142,92 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, - 
helper->numnodes, tp_size, + int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), + helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, + helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Bulk GEMM + COMM -** This function assumes the communication input is pre-copied to _ubuf -*/ -std::vector CommOverlap::bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - te::CommOverlapType comm_type, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto 
rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, - grad, accumulate, use_split_accumulator, comm_type, rs_out_, - stream_main); - - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - if (comm_type == te::CommOverlapType::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = - (comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - auto output_tensor = - torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); - - return {D, output_tensor}; -} // CommOverlap::bulk_overlap - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // 
CommOverlap::split_overlap_rs + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - te::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs +void CommOverlap::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + _ubuf_scale_inv_initialized = true; +} /* ** Helper function to copy input to _ubuf */ -void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { +void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? 
input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type == te::CommOverlapType::AG) { - if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + if (local_chunk) { + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size(); } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); } + // Copy either row or columnwise data into the communication buffer's columnwise data + // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with + // the Userbuffers communicator. 
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(), + input_tensor.numel() * input_tensor.element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); } -torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { +py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? 
_ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; } /*************************************************************************************************** @@ -333,148 +237,85 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm, bool use_ce, bool aggregate) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, + bool aggregate) : te::CommOverlapP2PBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, 
helper->numlocal, helper->mynode, helper->numnodes, + tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Split AllGather + AtomicGEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. 
-*/ -void CommOverlapP2P::atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, B_copy_, stream_main); -} // atomic_gemm_overlap_ag + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm, aggregate) {} -/* -** Split AllGather + GEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. 
-*/ -void CommOverlapP2P::split_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - B_copy_, stream_main); -} // split_overlap_ag - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, 
pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, rs_out_, stream_main); -} - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - rs_out_, stream_main); +void CommOverlapP2P::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]); } /* ** Copy input to _ubufs[0] */ -void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { +void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? 
input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { + if (local_chunk) { // Copy input to the target ubuf chunk by rank offset - if (input.numel() != (int64_t)_ubufs[0].numel() || - input.element_size() != (int64_t)_ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } } -torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - 
te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); +py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; + char *ubuf_wt_ptr = reinterpret_cast(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr()); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. 
+ auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 250c9993fb..b044c9f604 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -4,74 +4,272 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include + +#include +#include + +#include "../common.h" +#include "common.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" -void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine; - if (A.numel() == 0 || B.numel() == 0) { - if (D.numel() != 0 && !accumulate) D.zero_(); - if (bias.numel() != 0 && grad) { - if (B.numel() == 0) { - bias.zero_(); - } else { - bias.copy_(B.sum(0)); +namespace { + +void* get_data_ptr(MaybeTensor tensor) { + if (tensor.has_value()) return tensor->data_ptr(); + return nullptr; +} + +size_t get_size(MaybeTensor tensor, int dim) { + if (tensor.has_value()) return 
static_cast(tensor->size(dim)); + return 0; +} + +} // namespace + +namespace transformer_engine::pytorch { + +namespace detail { + +std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, + const NVTEShape& B_shape, const bool transb) { + // Flatten outer dims to get 2D matrices + const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); + const size_t A1 = A_shape.data[A_shape.ndim - 1]; + const size_t B0 = product(B_shape, 0, B_shape.ndim - 1); + const size_t B1 = B_shape.data[B_shape.ndim - 1]; + + // Check matrix dims + NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(", + A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); + + // Construct output dims + std::vector ret; + if (transb) { + ret.emplace_back(B1); + } else { + // Unflatten B0 + for (size_t i = 0; i < B_shape.ndim - 1; ++i) { + ret.emplace_back(B_shape.data[i]); + } + } + if (transa) { + ret.emplace_back(A0); + } else { + ret.emplace_back(A1); + } + return ret; +} + +bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { + if (expected.size() != actual.ndim) return false; + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != actual.data[i]) return false; + } + return true; +} + +} // namespace detail + +std::pair createOutputTensor(const std::vector& shape, + DType dtype, py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore* comm_overlap, + std::optional comm_type, MaybeTensor extra_output, + bool bulk_overlap) { + // Input tensors + NVTE_CHECK(!A.is_none(), 
"Tensor A has not been provided"); + NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); + auto none = py::none(); + TensorWrapper A_tensor = makeTransformerEngineTensor(A, none); + TensorWrapper B_tensor = makeTransformerEngineTensor(B, none); + + // Check tensor dimensions + const auto& A_shape = A_tensor.shape(); + const auto& B_shape = B_tensor.shape(); + const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); + NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); + + // Output tensor + TensorWrapper D_tensor; + if (D.is_none()) { + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + } else { + D_tensor = makeTransformerEngineTensor(D, quantizer); + NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), + "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", + std::to_string(D_tensor.shape()), ")"); + if (out_dtype) { + NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", + static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); + } + } + + // Bias tensor + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if (bias.has_value()) { + if (grad) { + auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); + bias_tensor = makeTransformerEngineTensor(*bias_grad); + } else { + if (!bias->is_contiguous()) { + bias = bias->contiguous(); } + bias_tensor = makeTransformerEngineTensor(*bias); } - if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_(); - return; } - A = A.contiguous(); - B = B.contiguous(); + // Activation input tensor + MaybeTensor pre_gelu_out = std::nullopt; + DType gelu_type = bias_type; + if (gelu) { + if 
(!grad) { + auto dtype = GetATenDType(gelu_type); + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + std::vector torch_shape; + for (auto v : D_shape) { + torch_shape.push_back(v); + } + pre_gelu_out = at::empty(torch_shape, opts); + } else { + if (gelu_in.has_value()) { + pre_gelu_out = *gelu_in; + } + } + } + const auto gelu_shape = gelu ? D_shape : std::vector{0}; - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr()); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr()); - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = - makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + auto te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor( - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + // Workspace auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + // Set an external SM Margin to all the GEMMs. 
+ // This comes in handy when DP is overlapped with GEMMs + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + auto main_stream = at::cuda::getCurrentCUDAStream(); + if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { + if (comm_overlap) { + // Prepare extra output tensor + TensorWrapper extra_output_tensor; + if (extra_output.has_value()) { + extra_output_tensor = makeTransformerEngineTensor(*extra_output); + } else { + extra_output_tensor = + makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + } + + // Direct GEMM call to the correct overlap + if (bulk_overlap) { + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, comm_type.value(), extra_output_tensor, + main_stream); + } else if (comm_type.value() == CommOverlapType::AG) { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } else { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } + } else { + 
// Launch GEMM + nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, num_math_sms, main_stream); + } + } else { + if (D_tensor.numel() != 0 && !accumulate) { + D_tensor.zero_(main_stream); + } + if (bias.has_value()) { + if (bias->numel() != 0 && grad) { + bias_grad->zero_(); + } + } + } + + // Pack outputs + std::vector out; + out.emplace_back(std::move(D)); + out.emplace_back(py::cast(bias_grad)); + if (gelu && !grad) { + out.emplace_back(py::cast(*pre_gelu_out)); + } else { + out.emplace_back(py::none()); + } + if (extra_output.has_value()) { + out.emplace_back(py::cast(extra_output)); + } else { + out.emplace_back(py::none()); + } + return out; } +} // namespace transformer_engine::pytorch + void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + // TODO: Handle scaling modes + NVTEScalingMode nvte_scaling_modeA = 
NVTE_DELAYED_TENSOR_SCALING; + NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr()); + nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), + nvte_scaling_modeA); auto te_B = makeTransformerEngineTensor( B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr()); + nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), + nvte_scaling_modeB); + // TODO: D_scale_inv cannot be nullptr when D_type is FP8. auto te_D = makeTransformerEngineTensor( D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, D_amax.data_ptr(), D_scale.data_ptr(), nullptr); @@ -95,134 +293,108 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream()); } -void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, - bool transb, std::vector D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - std::vector bias, transformer_engine::DType bias_type, - std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { using 
namespace transformer_engine; - std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; + using namespace transformer_engine::pytorch; + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector, te_workspace_vector; + std::vector wrappers; + std::vector D_vectors; + + auto none = py::none(); + + std::vector single_output_begins; + std::vector single_output_ends; + int slicing_dim; + if (single_output && D == std::nullopt) { + NVTE_ERROR("not implemented, D should be allocated for single output case."); + } + + void* output_data_ptr; + if (single_output) { + output_data_ptr = (*D)[0].data_ptr(); + } + for (size_t i = 0; i < A.size(); i++) { - if (A[i].numel() == 0 || B[i].numel() == 0) { - if (D[i].numel() != 0 && !accumulate) D[i].zero_(); + auto te_A = makeTransformerEngineTensor(A[i], none); + auto te_B = makeTransformerEngineTensor(B[i], none); + + // if there is single output + at::Tensor out_tensor; + auto size_t_shape = + pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); + std::vector D_shape; + for (size_t t : size_t_shape) { + D_shape.push_back(t); + } + auto dtype = GetATenDType(D_type); + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + if (single_output) { + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + char* char_ptr = reinterpret_cast(output_data_ptr); + char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size(); + output_data_ptr = reinterpret_cast(char_ptr); + D_vectors.emplace_back(out_tensor); + } else { + if (D == std::nullopt) { + auto opts = 
torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + out_tensor = at::empty(D_shape, opts); + D_vectors.emplace_back(out_tensor); + } else { + out_tensor = (*D)[i]; + } + } + + if (te_A.numel() == 0 || te_B.numel() == 0) { + if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); if (bias[i].numel() != 0 && grad) { - if (B[i].numel() == 0) { - bias[i].zero_(); - } else { - bias[i].copy_(B[i].sum(0)); - } + bias[i].zero_(); } if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); continue; } - NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); - NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); - NVTE_CHECK(D[i].is_contiguous(), "D[", i, "] must be contiguous."); - - te_A.emplace_back(make_tensor( - A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, - A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i))); - te_B.emplace_back(make_tensor( - B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, - B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); - te_D.emplace_back(make_tensor( - D[i].data_ptr(), {static_cast(D[i].size(0)), static_cast(D[i].size(1))}, - D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); - te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, - bias_type, nullptr, nullptr, nullptr)); + auto te_D = makeTransformerEngineTensor(out_tensor); + auto te_bias = makeTransformerEngineTensor(bias[i]); + auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out[i].size(0))} - : std::vector{static_cast(pre_gelu_out[i].size(0)), - static_cast(pre_gelu_out[i].size(1))}; - te_pre_gelu_out.emplace_back(make_tensor( - pre_gelu_out[i].data_ptr(), gelu_shape, - GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - } - for (size_t i = 0; i < workspace.size(); i++) { - te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, - nullptr, nullptr, nullptr)); - } + ? std::vector{static_cast(te_pre_gelu_out.size(0))} + : std::vector{static_cast(te_pre_gelu_out.size(0)), + static_cast(te_pre_gelu_out.size(1))}; - // For now, we only have multi-stream cublas backend. - nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, - te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); -} + DType gelu_type = bias_type; + te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); -void te_grouped_gemm_single_output( - std::vector A, std::vector A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, - std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, - transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine; - std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* 
scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); - void* d_i_ptr = reinterpret_cast(D.data_ptr()); - for (size_t i = 0; i < A.size(); i++) { - if (m_splits[i] == 0) continue; - NVTE_CHECK(A[i].data_ptr() != nullptr, "A[", i, "] must not be nullptr."); - NVTE_CHECK(B[i].data_ptr() != nullptr, "B[", i, "] must not be nullptr."); - NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); - NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); - te_A.emplace_back(make_tensor( - A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, - A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); - te_B.emplace_back(make_tensor( - B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, - B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); - te_D.emplace_back(make_tensor( - d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, - getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); - te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, - bias_type, nullptr, nullptr, nullptr)); + te_A_vector.emplace_back(te_A.data()); + te_B_vector.emplace_back(te_B.data()); + te_D_vector.emplace_back(te_D.data()); + te_bias_vector.emplace_back(te_bias.data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out[i].size(0))} - : std::vector{static_cast(pre_gelu_out[i].size(0)), - static_cast(pre_gelu_out[i].size(1))}; - te_pre_gelu_out.emplace_back(make_tensor( - pre_gelu_out[i].data_ptr(), gelu_shape, - GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - // Move the D pointer to the next split. - char* char_ptr = reinterpret_cast(d_i_ptr); - char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); - d_i_ptr = reinterpret_cast(char_ptr); + wrappers.emplace_back(std::move(te_A)); + wrappers.emplace_back(std::move(te_B)); + wrappers.emplace_back(std::move(te_D)); + wrappers.emplace_back(std::move(te_bias)); + wrappers.emplace_back(std::move(te_pre_gelu_out)); } for (size_t i = 0; i < workspace.size(); i++) { - te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, - nullptr, nullptr, nullptr)); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte); + te_workspace_vector.emplace_back(wsp.data()); + wrappers.emplace_back(std::move(wsp)); } - // For now, we only have multi-stream cublas backend. 
- nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, - te_workspace.data(), accumulate, use_split_accumulator, + nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), + te_A_vector.size(), transa, transb, grad, + te_workspace_vector.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); + return bias; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 2124b551fd..66ad03381c 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -6,10 +6,29 @@ #include "extensions.h" -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, +namespace transformer_engine::pytorch { +std::pair createOutputTensor(const NVTEShape &shape, DType dtype, + py::handle quantizer) { + std::vector shape_vec; + for (int i = 0; i < shape.ndim; i++) { + size_t t = shape.data[i]; + shape_vec.push_back(t); + } + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape_vec, dtype); +} +std::pair createOutputTensor(std::vector &shape, DType dtype, + py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} +} // namespace transformer_engine::pytorch + +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; const auto &dz_ = dz.contiguous(); const auto &x_ = x.contiguous(); const auto &mu_ = mu.contiguous(); @@ -47,61 +66,57 @@ std::vector layernorm_bwd(const at::Tensor &dz, 
const at::Tensor &x, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {dx, dgamma, dbeta}; + return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)}; } -std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset) { +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + DType out_dtype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; using namespace transformer_engine; - const auto &input_ = input.contiguous(); + auto none = py::none(); + const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); - return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype, - sm_margin, zero_centered_gamma, scale_offset, amax_offset, - scale_inv_offset); -} - -std::vector layernorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, - at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, const int scale_inv_offset) { - using namespace transformer_engine; - - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - const auto &bias_ = bias.contiguous(); + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if 
(bias.has_value()) { + bias_tensor = makeTransformerEngineTensor(*bias); + } // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void *scale_dptr = getDataPtr(scale, scale_offset); - void *amax_dptr = getDataPtr(amax, amax_offset); - void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + size_t N = static_cast(input_tensor.size(0)); + size_t H = static_cast(input_tensor.size(1)); + std::vector size = {N, H}; // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input_); - auto gamma_cu = makeTransformerEngineTensor(weight_); - auto beta_cu = makeTransformerEngineTensor(bias_); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, - scale_inv_dptr); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + + TensorWrapper ln_out_tensor; + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + py::object ln_output; + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // Use high precision output from normalization + NoneQuantizer q{none}; + std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype); + } else { + if (ln_out.is_none()) { + std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); + } else { + ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + } + TensorWrapper mu_cu = makeTransformerEngineTensor(mu); + TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes 
transformer_engine::TensorWrapper workspace; - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, + ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); @@ -111,66 +126,30 @@ std::vector layernorm_fwd_fp8_noalloc( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, + ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {ln_out, mu, rsigma}; -} - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset - -) { - // This is a specialized version of layernorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. 
- std::vector out = - layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, - zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); - return out[0]; -} - -std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - const auto &input_ = input.contiguous(); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - - return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); -} - -std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, at::Tensor ln_out, float eps, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine; + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + TensorWrapper cast_out_tensor; + if (ln_out.is_none()) { + std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); + } else { + cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } - DType itype = GetTransformerEngineDType(input.scalar_type()); + nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + at::cuda::getCurrentCUDAStream()); + } - return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(), - at::Tensor(), itype, sm_margin, zero_centered_gamma); + return {ln_out, py::cast(mu), py::cast(rsigma)}; } -at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma) { - // This is a specialized version of layernorm_fwd, optimized for inference, - // which only returns the normalized output. 
- std::vector out = - layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma); - return out[0]; -} - -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; const auto &dz_ = dz.contiguous(); const auto &x_ = x.contiguous(); const auto &rsigma_ = rsigma.contiguous(); @@ -204,57 +183,48 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {dx, dgamma}; + return {py::cast(dx), py::cast(dgamma)}; } -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; using namespace transformer_engine; - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); - return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype, - sm_margin, zero_centered_gamma, scale_offset, amax_offset, - scale_inv_offset); -} - -std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor ln_out, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - const int 
sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { - using namespace transformer_engine; + auto none = py::none(); + const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void *scale_dptr = getDataPtr(scale, scale_offset); - void *amax_dptr = getDataPtr(amax, amax_offset); - void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + size_t N = static_cast(input_tensor.shape().data[0]); + size_t H = static_cast(input_tensor.shape().data[1]); // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, - scale_inv_dptr); + std::vector size = {N, H}; + TensorWrapper ln_out_tensor; + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + py::object ln_output; + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // Use high precision output from normalization + NoneQuantizer q{none}; + std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, otype); + } else { + if (ln_out.is_none()) { + std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); + } else { + ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + } auto rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes transformer_engine::TensorWrapper workspace; - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), + 
nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), + rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); @@ -264,55 +234,22 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), + rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {ln_out, rsigma}; -} + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + TensorWrapper cast_out_tensor; + if (ln_out.is_none()) { + std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); + } else { + cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset) { - // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. 
- std::vector out = - rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin, - zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); - return out[0]; -} - -std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine; - - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - - return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma); -} - -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(), - at::Tensor(), itype, sm_margin, zero_centered_gamma); -} + nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + at::cuda::getCurrentCUDAStream()); + } -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma) { - // This is a specialized version of rmsnorm_fwd, optimized for inference, - // which only returns the normalized output. 
- std::vector out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma); - return out[0]; + return {ln_out, py::none(), py::cast(rsigma)}; } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index ca10e4d3c9..b9972af7cb 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -10,6 +10,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), "Number of input row list and padded row list must match."); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index f363e6e7ea..47282da504 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -11,6 +11,7 @@ std::tuple> moe_permute_fwd( at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { + using namespace transformer_engine::pytorch; const int num_tokens = input.size(0); int num_cols = input.size(1); const int topK = indices.size(1); @@ -96,6 +97,7 @@ at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dty at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, int64_t topK) { + using namespace transformer_engine::pytorch; int num_cols = input.size(1); // Activations type @@ -129,6 +131,7 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, const 
transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob) { + using namespace transformer_engine::pytorch; const int topK = (prob.numel() > 0) ? prob.size(1) : 1; const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e5d8744eef..442837d767 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -4,14 +4,131 @@ * See LICENSE for license information. ************************************************************************/ +#include "pybind.h" + +#include +#include +#include +#include #include #include +#include + +#include "../common.h" #include "../extensions.h" +#include "common.h" + +namespace transformer_engine::pytorch { + +PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8QuantizerClass = nullptr; +PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8QuantizerClass = nullptr; + +void init_float8_extension() { + if (Float8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); + Float8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); + Float8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); + Float8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + NVTE_CHECK(Float8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch Float8 extension."); +} + +void 
init_mxfp8_extension() { + if (MXFP8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); + MXFP8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); + MXFP8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); + MXFP8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + NVTE_CHECK(MXFP8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch MXFP8 extension."); +} + +void init_extension() { + init_float8_extension(); + init_mxfp8_extension(); +} + +} // namespace transformer_engine::pytorch + #include "common/util/pybind_helper.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), + py::arg("output") = py::none(), py::arg("noop") = py::none()); + m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), + py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, + "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); + m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", + py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), + py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), + py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), + py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), + py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); + 
m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.", + py::call_guard()); + m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.", + py::call_guard()); + m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + py::arg("quantizer")); + m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + 
m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu, + "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer")); // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); @@ -42,116 +159,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Other granular functions - m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), - py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, 
py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("layernorm_bwd", &layernorm_bwd, "LN BWD", py::call_guard()); - m.def("layernorm_fwd", &layernorm_fwd, "LN FWD", py::call_guard()); - m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD", - py::call_guard()); - m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD", py::call_guard()); - m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD", py::call_guard()); - m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD", - py::call_guard()); - m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose", - py::call_guard()); - m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, - "Cast + Transpose with noop option", py::call_guard(), - py::arg("input"), py::arg("noop"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("input_cast"), py::arg("input_transpose"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD", - py::call_guard(), py::arg("grad_output"), py::arg("scale"), - py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), 
py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD", - py::call_guard(), py::arg("grad_output"), py::arg("scale"), - py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("grad_bias_type"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, - "Fused Cast + Transpose + BGRAD + DGELU", py::call_guard(), - py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, - "Fused SwiGLU backward + FP8 cast + FP8 transpose", - py::call_guard(), py::arg("grad_output"), py::arg("input"), - py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, - "Fused Multi-tensor Cast + Transpose", py::call_guard()); - m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, - "Fused Multi-tensor Cast + Transpose with allocating output tensors", - py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), - py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard(), py::arg("input"), py::arg("scale"), - py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("scale_offset") = 0, 
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), - py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), - py::arg("scale_inv_offset") = 0); - m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think - m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); - m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed QKV", - py::call_guard()); - m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed QKV", - py::call_guard()); - m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed KV", - py::call_guard()); - m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed KV", - py::call_guard()); + m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), + py::arg("sm_margin"), py::arg("zero_centered_gamma")); + m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"), + py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), + py::arg("zero_centered_gamma")); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); + m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", + py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + + m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd", &fused_attn_fwd, - "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V", - py::call_guard()); + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", 
&fused_attn_bwd, - "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V", - py::call_guard()); - m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, - "Transpose with FP8 I/O with noop option.", py::call_guard()); - m.def("gelu", &gelu, "GeLU with FP8 output", py::call_guard()); - m.def("relu", &relu, "ReLU with FP8 output", py::call_guard()); - m.def("geglu", &geglu, "GeGLU with FP8 output", py::call_guard()); - m.def("reglu", ®lu, "ReGLU with FP8 output", py::call_guard()); - m.def("swiglu", &swiglu, "SwiGLU with FP8 output", py::call_guard()); - m.def("qgelu", &qgelu, "QuickGELU with FP8 output", py::call_guard()); - m.def("srelu", &srelu, "Squared ReLU with FP8 output", py::call_guard()); - m.def("dgelu", &dgelu, "Backward of GeLU", py::call_guard()); - m.def("drelu", &drelu, "Backward of ReLU", py::call_guard()); - m.def("dgeglu", &dgeglu, "Backward of GeGLU", py::call_guard()); - m.def("dreglu", &dreglu, "Backward of ReGLU", py::call_guard()); - m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard()); - m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard()); - m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard()); + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); + m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), + py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", py::call_guard()); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", @@ -233,30 +258,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Data structures - py::class_(m, "FP8TensorMeta") + py::class_(m, "FP8TensorMeta") .def(py::init<>()) - 
.def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) - .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) - .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); + .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale) + .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv) + .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history); - py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); + py::enum_(m, "FP8FwdTensors") + .value("GEMM1_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT); - py::enum_(m, "FP8BwdTensors") - 
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); + py::enum_(m, "FP8BwdTensors") + .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) @@ -265,54 +290,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("world_group"), py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); - py::class_(m, "CommOverlap") + py::class_, transformer_engine::CommOverlapBase, + transformer_engine::CommOverlapCore>(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, bool, bool>(), + int, int, int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, - py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, - py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) - .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) - .def("split_overlap_rs", &CommOverlap::split_overlap_rs, - py::call_guard()) - 
.def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlap::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, - py::call_guard()) - .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) - .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) - .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); + py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, + py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, + py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) + .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlap::set_buffer_params); - py::class_(m, "CommOverlapP2P") + py::class_, + transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( + m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, - transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), + transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, + bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, - py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, - py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, - py::call_guard()) - .def("split_overlap_rs_p2p", 
&CommOverlapP2P::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, - py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, - py::call_guard()) - .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) - .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, - py::call_guard()); + py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, + py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, + py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlapP2P::set_buffer_params); } diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp new file mode 100644 index 0000000000..effeb8cb4d --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -0,0 +1,227 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include "common.h" +#include "pybind.h" +#include "torch/torch.h" +#include "util.h" + +namespace transformer_engine::pytorch { + +constexpr size_t MXFP8_BLOCK_SIZE = 32; + +Quantizer::Quantizer(const py::handle& quantizer) { + if (quantizer.is_none()) { + this->rowwise_usage = true; + this->columnwise_usage = true; + this->internal = false; + } else { + this->rowwise_usage = quantizer.attr("rowwise_usage").cast(); + this->columnwise_usage = quantizer.attr("columnwise_usage").cast(); + this->internal = quantizer.attr("internal").cast(); + this->quantizer = quantizer; + } +} + +Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; +} + +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + at::TensorOptions opts; + opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + at::Tensor ret; + if (rowwise_data.has_value()) { + ret = std::move(*rowwise_data); + } else { + ret = at::empty(torch_shape, opts); + } + + TensorWrapper tensor; + tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); + return {std::move(tensor), py::cast(ret)}; +} + +void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + auto rowwise_data = 
tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? 
py::cast(columnwise_data) : py::none(); + opts = opts.dtype(torch::kFloat32); + at::Tensor scale_inv = at::reciprocal(scale); + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + +MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); +} + +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + 
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair MXFP8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); + at::TensorOptions opts; + at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, + columnwise_scale_inv; // TODO(pgadzinski) - change + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + auto last_dim = static_cast(torch_shape.back()); + + NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", torch_shape, ")"); + + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(torch_shape, opts); + } + auto sinv0 = roundup(numel / last_dim, 128); + auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + auto sinv1 = roundup(last_dim, 128); + columnwise_data = at::empty(torch_shape, opts); + columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); + tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle 
MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); + ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index ec75a2a8c6..e8a31da99a 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -9,20 +9,22 @@ #include +#include "common/common.h" #include "extensions.h" -void fused_amax_and_scale_update_after_reduction( - const at::Tensor &amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; size_t num_tensors = amax_histories.size(); std::vector t_amax_histories(num_tensors); std::vector t_scales(num_tensors); - std::vector t_scale_invs(num_tensors); std::vector te_amax_histories(num_tensors); std::vector te_scales(num_tensors); - 
std::vector te_scale_invs(num_tensors); for (size_t i = 0; i < num_tensors; i++) { t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); auto amax_sizes = amax_histories[i].sizes().vec(); @@ -36,18 +38,11 @@ void fused_amax_and_scale_update_after_reduction( t_scales[i].data.shape = scale_shape; t_scales[i].data.dtype = DType::kFloat32; - t_scale_invs[i].data.dptr = scale_invs[i].data_ptr(); - auto scale_inv_sizes = scale_invs[i].sizes().vec(); - std::vector scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()}; - t_scale_invs[i].data.shape = scale_inv_shape; - t_scale_invs[i].data.dtype = DType::kFloat32; - te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); te_scales[i] = reinterpret_cast(&t_scales[i]); - te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, - te_scale_invs, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, + amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index 93be90c9f3..02f8fcbdf6 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -7,7 +7,7 @@ #include "extensions.h" at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), @@ -38,7 +38,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace 
transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -65,7 +65,7 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r } at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || @@ -105,7 +105,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -132,7 +132,7 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so } at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || @@ -159,7 +159,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -188,7 +188,7 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, } at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { - using namespace 
transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), @@ -220,7 +220,7 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp new file mode 100644 index 0000000000..316e6515bf --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -0,0 +1,120 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "extensions.h" +#include "transformer_engine/transformer_engine.h" + +void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (input.scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + return; + } + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + NVTEBasicTensor scale_inv; + if (rowwise) { + scale_inv = input.get_rowwise_scale_inv(); + } else { + scale_inv = input.get_columnwise_scale_inv(); + } + + auto input_shape = nvte_shape_to_vector(input.shape()); + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + + // Allocate memory for swizzled output. + auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + std::vector scale_inv_shape_int; + for (size_t i = 0; i < scale_inv_shape.size(); ++i) { + scale_inv_shape_int.push_back(static_cast(scale_inv_shape[i])); + } + auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options); + void* scale_inv_dptr = scale_inv.data_ptr; + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. 
+ transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, + scale_inv_shape); + } + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } +} + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input), + DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr, + getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + input.data_ptr(), getTensorShape(input), 
DType::kFloat8E4M3, nullptr, nullptr, + swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return swizzled_scale_inv; +} + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + // Return immediately if tensor is empty + if (scale_inv.numel() == 0) { + return swizzled_scale_inv; + } + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv), + NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return swizzled_scale_inv; +} diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 40f76c898c..37fbddcc18 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -4,434 +4,104 @@ * See LICENSE for license information. 
************************************************************************/ -#include "extensions.h" - -void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - at::Tensor input_cast, at::Tensor input_transpose, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input); - auto output_cast_cu = - makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - auto output_transpose_cu = - makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); +#include - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, at::Tensor input_cast, - at::Tensor input_transpose, transformer_engine::DType otype, - int scale_offset, int amax_offset, int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(input); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - - 
// Launch kernel - nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), - output_transpose_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); +#include "ATen/core/TensorBody.h" +#include "extensions.h" - // Allocate output tensors - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto grad_output_cast = - allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype) { + using namespace transformer_engine::pytorch; + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + std::vector py_output_objects_list; + std::vector tensor_wrappers; + auto none = py::none(); + + // create TE tensors from input + for (int i = 0; i < input_list.size(); i++) { + auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + const NVTEShape input_shape = input_tensor.shape(); + + transformer_engine::TensorWrapper output_tensor; + + if (output_list == std::nullopt) { + std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + py::object o; + std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype); + py_output_objects_list.push_back(o); + } else { + 
output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]); + } + if (input_tensor.numel() == 0) continue; - // Return immediately if tensors are empty - if (M == 0 || N == 0) { - return {grad_bias.zero_(), grad_output_cast, grad_output_transpose}; + nvte_tensor_output_list.emplace_back(output_tensor.data()); + nvte_tensor_input_list.emplace_back(input_tensor.data()); + tensor_wrappers.emplace_back(std::move(input_tensor)); + tensor_wrappers.emplace_back(std::move(output_tensor)); } - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor( - grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, 
at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); - auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor( - grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_transpose}; -} - -std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, int amax_offset, - int 
scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto dgelu = allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); - auto dgelu_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); - auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - - return {grad_bias, dgelu, 
dgelu_transpose}; -} - -void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, - at::Tensor grad_input_transpose, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, int scale_offset, - int amax_offset, int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - auto outer_dim = [](const at::Tensor& tensor) -> size_t { - return tensor.numel() / tensor.size(-1); - }; - const auto M = outer_dim(grad_output); - const auto N = static_cast(grad_output.size(-1)); - - // Check tensor dims - NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ", - grad_output.dim()); - NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim()); - NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, - ", but found ", outer_dim(input)); - NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, - ", but found ", input.size(-1)); - NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", - grad_input.dim()); - NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", - M, ", but found ", outer_dim(grad_input)); - NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ", - 2 * N, ", but found ", grad_input.size(-1)); - NVTE_CHECK(grad_input_transpose.dim() == 2, - "Expected grad input transpose tensor to have 2 dims, but found ", - grad_input_transpose.dim()); - NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, - "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", - grad_input_transpose.size(0)); - NVTE_CHECK(grad_input_transpose.size(1) == M, - "Expected grad input tensor to have outer dimension of ", M, ", but found ", - grad_input_transpose.size(1)); - - // Check tensor format - 
NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); - NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); - NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); - NVTE_CHECK(grad_input_transpose.is_contiguous(), - "Expected grad input transpose tensor to be contiguous"); - NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), - "Expected grad output tensor and input tensor to have same dtype"); - NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, - "Expected grad input tensor to be uint8 buffer"); - NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, - "Expected grad input transpose tensor to be uint8 buffer"); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto dy_cu = makeTransformerEngineTensor(grad_output); - auto x_cu = makeTransformerEngineTensor(input); - auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - - // Launch kernel - nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_multi_cast_transpose_base(std::vector input_list, - std::vector scale_dptr_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_dptr_list, - std::vector scale_inv_dptr_list, - transformer_engine::DType otype) { - using namespace transformer_engine; - - // Extract properties from PyTorch tensors - std::vector input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list; - std::vector> 
input_shape_list, cast_output_shape_list, - transposed_output_shape_list; - std::vector input_type_list, cast_output_type_list, - transposed_output_type_list; - auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector& dptr_list, - std::vector>& shape_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); + // Check tensor lists + NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), + "Number of input and output tensors must match"); + + // Choose implementation + // Note: Currently only have fused kernel for FP8 cast-transpose + bool with_fused_kernel = true; + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + const auto& tensor = nvte_tensor_output_list[i]; + if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) { + with_fused_kernel = false; + break; } - }; - auto extract_tensor_props = [](at::Tensor& tensor, std::vector& dptr_list, - std::vector>& shape_list, - std::vector& type_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); + if (nvte_tensor_columnwise_data(tensor) == nullptr) { + with_fused_kernel = false; + break; } - type_list.push_back(GetTransformerEngineDType(tensor.scalar_type())); - }; - for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { - extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list); - extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list, - cast_output_shape_list); - cast_output_type_list.push_back(otype); - extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list, - transposed_output_shape_list); - transposed_output_type_list.push_back(otype); } - // Construct TE tensors - std::vector nvte_input_list, 
nvte_cast_output_list, nvte_transposed_output_list; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - for (size_t i = 0; i < input_dptr_list.size(); ++i) { - if (input_dptr_list[i] == nullptr) continue; - nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i], - input_type_list[i], nullptr, nullptr, nullptr)); - nvte_cast_output_list.emplace_back( - make_tensor(cast_output_dptr_list[i], cast_output_shape_list[i], cast_output_type_list[i], - amax_dptr_list[i], scale_dptr_list[i], scale_inv_dptr_list[i])); - nvte_transposed_output_list.emplace_back( - make_tensor(transposed_output_dptr_list[i], transposed_output_shape_list[i], - transposed_output_type_list[i], amax_dptr_list[i], scale_dptr_list[i], - scale_inv_dptr_list[i])); - } - - // Check tensor lists - NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(), - "Number of input and C output tensors must match"); - NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(), - "Number of input and T output tensors must match"); - // Launch TE kernel - nvte_multi_cast_transpose(nvte_input_list.size(), nvte_input_list.data(), - nvte_cast_output_list.data(), nvte_transposed_output_list.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_list, - std::vector scale_inv_list, - transformer_engine::DType otype) { - std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; - for (size_t i = 0; i < scale_list.size(); ++i) { - 
scale_dptr_list.push_back(scale_list[i].data_ptr()); - amax_dptr_list.push_back(amax_list[i].data_ptr()); - scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr()); + if (with_fused_kernel) { + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); + } else { + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], + at::cuda::getCurrentCUDAStream()); + } } - - fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, - transposed_output_list, amax_dptr_list, scale_inv_dptr_list, - otype); + return py_output_objects_list; } -std::tuple, std::vector> fused_multi_cast_transpose_alloc( - std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - std::vector scale_indices, std::vector amax_indices, - std::vector scale_inv_indices, transformer_engine::DType otype) { - using namespace transformer_engine; - - std::vector cast_output_list; - std::vector transposed_output_list; - std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; - for (size_t i = 0; i < input_list.size(); ++i) { - auto input_i = input_list[i]; - // construct cast output tensors - auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte); - cast_output_list.push_back(cast_output_i); - // construct transposed output tensors - auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte); - transposed_output_list.push_back(transposed_output_i); - // construct amax/scale/scale_inv dptr lists - amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i])); - scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i])); - scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i])); - } - - fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, - transposed_output_list, 
amax_dptr_list, scale_inv_dptr_list, - otype); +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output) { + using namespace transformer_engine::pytorch; - return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list)); -} + const auto dim = input.dim(); + NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; + if (input.dim() > 2) { + input = input.view({-1, input.size(dim - 1)}); + } size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); - auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); - if (M == 0 || N == 0) return output; - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); + at::Tensor out; + if (output.has_value()) { + out = *output; + } else { + out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + } + if (M == 0 || N == 0) return out; auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); 
- size_t N = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - nvte_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); + return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp new file mode 100644 index 0000000000..d2607e4ed0 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common.h" +#include "pybind.h" + +namespace transformer_engine::pytorch { +namespace detail { + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { + const at::Tensor &data = tensor.attr("_data").cast(); + const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast(); + float *scale_inv_dptr = reinterpret_cast(scale_inv.data_ptr()); + const DType dtype = tensor.attr("_fp8_dtype").cast(); + + const auto &shape = getTensorShape(data); + + bool transpose_valid = !tensor.attr("_transpose_invalid").cast(); + std::optional transpose = std::nullopt; + if (transpose_valid) { + transpose = tensor.attr("_transpose").cast>(); + } + + auto ret = TensorWrapper(quantizer->get_scaling_mode()); + + ret.set_rowwise_data(data.data_ptr(), dtype, shape); + if (transpose_valid && transpose != std::nullopt) { + const auto &transpose_shape = getTensorShape(*transpose); + ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape); + } + + const auto scale_inv_dtype = 
GetTransformerEngineDType(scale_inv.scalar_type()); + const auto scale_inv_shape = getTensorShape(scale_inv); + ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + quantizer->set_quantization_params(&ret); + return ret; +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape); + } + + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0, + scale_inv_colwise_shape); + } + + quantizer->set_quantization_params(&ret); + return ret; +} + +} // namespace detail + +} // namespace transformer_engine::pytorch diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.h 
b/transformer_engine/pytorch/csrc/extensions/util.cpp old mode 100755 new mode 100644 similarity index 53% rename from tests/pytorch/custom_ort_ops/custom_op_library.h rename to transformer_engine/pytorch/csrc/extensions/util.cpp index 747e6c5083..5f49383d11 --- a/tests/pytorch/custom_ort_ops/custom_op_library.h +++ b/transformer_engine/pytorch/csrc/extensions/util.cpp @@ -4,15 +4,11 @@ * See LICENSE for license information. ************************************************************************/ -#pragma once -#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "util.h" -#ifdef __cplusplus -extern "C" { -#endif +#include "ATen/cuda/CUDAContextLight.h" -ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); - -#ifdef __cplusplus +bool non_tn_fp8_gemm_supported() { + int major = at::cuda::getCurrentDeviceProperties()->major; + return major >= 10; } -#endif diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h new file mode 100644 index 0000000000..0679528b94 --- /dev/null +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -0,0 +1,73 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#define PYBIND11_DETAILED_ERROR_MESSAGES // TODO remove + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#include +#include +#include +#include + +#include "common.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +extern PyTypeObject *Float8TensorPythonClass; +extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8QuantizerClass; +extern PyTypeObject *MXFP8TensorPythonClass; +extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8QuantizerClass; + +void init_extension(); + +void init_float8_extension(); + +void init_mxfp8_extension(); + +namespace detail { + +inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } + +inline bool IsFloat8Tensor(PyObject *obj) { + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; +} + +inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } + +inline bool IsMXFP8Tensor(PyObject *obj) { + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; +} + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); + +template +std::unique_ptr CreateQuantizer(const py::handle quantizer) { + return std::make_unique(quantizer); +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params); + +std::unique_ptr CreateMXFP8Params(const py::handle params); + +inline bool IsFloatingPointType(at::ScalarType type) { + return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; +} + +constexpr std::array custom_types_converters = { + std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, 
NVTETensorFromMXFP8Tensor, + CreateQuantizer)}; + +} // namespace detail + +} // namespace transformer_engine::pytorch + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp deleted file mode 100644 index 203b575a0d..0000000000 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ /dev/null @@ -1,414 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "common/util/cuda_runtime.h" -#include "common/util/system.h" -#include "extensions.h" - -namespace { -transformer_engine::DType reverse_map_dtype(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(transformer_engine::DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } -} -} // namespace - -at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = - cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor); - return output; -} - -at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &scale, - at::Tensor output, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, - fp8_tensor); - return output; -} - -at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv, - int64_t fp8_tensor, int64_t itype, int64_t otype) { - 
transformer_engine::DType itype_arg = reverse_map_dtype(itype); - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); - return output; -} - -at::Tensor gelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = gelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor relu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = relu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor reglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = reglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor geglu_ts(at::Tensor input, at::Tensor scale, at::Tensor 
amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = geglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor swiglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = swiglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor qgelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = qgelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor srelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a 
= amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = srelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - int64_t A_type, int64_t transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, int64_t B_type, int64_t transb, at::Tensor D, - at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, at::Tensor bias, - int64_t bias_type, at::Tensor pre_gelu_out, int64_t grad, - at::Tensor workspace, int64_t workspaceSize, int64_t accumulate, - int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - te_gemm(A, A_scale_inverse, A_type_arg, transa_arg, B, B_scale_inverse, B_type_arg, transb_arg, D, - D_scale, D_type_arg, D_amax, bias, bias_type_arg, pre_gelu_out, grad_arg, workspace, - workspaceSize_arg, accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -std::vector te_grouped_gemm_ts( - std::vector A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type, - int64_t transa, std::vector B, at::Tensor B_scale_inverse, int64_t B_offset, - int64_t B_type, int64_t transb, std::vector D, int64_t D_offset, at::Tensor D_scale, - int64_t D_type, at::Tensor D_amax, std::vector bias, int64_t bias_type, - std::vector pre_gelu_out, int64_t grad, std::vector workspace, - int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, - B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias, - bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg, - accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -at::Tensor te_grouped_gemm_single_output_ts( - std::vector A, std::vector A_scale_inverse, int64_t A_offset, - int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, - int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, - int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, - std::vector bias, int64_t bias_type, std::vector pre_gelu_out, - int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, - int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, - B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, - D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, - pre_gelu_out, grad_arg, workspace, workspaceSize_arg, - accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, double eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, - int64_t otype, const int64_t sm_margin, - const bool zero_centered_gamma) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - float eps_float = static_cast(eps); - - at::Tensor output = layernorm_fwd_fp8_inf(input, weight, bias, eps_float, scale, amax, scale_inv, - otype_arg, sm_margin, zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset - fp8_tensor); // scale_inv_offset - - return output; -} - -at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, double eps, const int64_t sm_margin, - const bool zero_centered_gamma) { - float eps_float = static_cast(eps); - - at::Tensor output = - layernorm_fwd_inf(input, weight, bias, eps_float, sm_margin, zero_centered_gamma); - - return output; -} - -at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype, const int64_t sm_margin, - const bool zero_centered_gamma) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - float eps_float = static_cast(eps); - - 
at::Tensor output = rmsnorm_fwd_fp8_inf(input, weight, eps_float, scale, amax, scale_inv, - otype_arg, sm_margin, zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset - fp8_tensor); // scale_inv_offset - - return output; -} - -at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, - const int64_t sm_margin, const bool zero_centered_gamma) { - float eps_float = static_cast(eps); - - at::Tensor output = rmsnorm_fwd_inf(input, weight, eps_float, sm_margin, zero_centered_gamma); - - return output; -} - -TORCH_LIBRARY(tex_ts, m) { - m.def("cast_to_fp8_ts", &cast_to_fp8_ts); - m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts); - m.def("cast_from_fp8_ts", &cast_from_fp8_ts); - m.def("gelu_ts", &gelu_ts); - m.def("relu_ts", &relu_ts); - m.def("geglu_ts", &geglu_ts); - m.def("reglu_ts", ®lu_ts); - m.def("swiglu_ts", &swiglu_ts); - m.def("qgelu_ts", &qgelu_ts); - m.def("srelu_ts", &srelu_ts); - m.def("te_gemm_ts", &te_gemm_ts); - m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); - m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); - m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); - m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); - m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); - m.def("rmsnorm_fwd_inf_ts", &rmsnorm_fwd_inf_ts); -} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h new file mode 100644 index 0000000000..cbdf0833ed --- /dev/null +++ b/transformer_engine/pytorch/csrc/util.h @@ -0,0 +1,12 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ + +bool non_tn_fp8_gemm_supported(); + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e6d63ab9e4..aa5964bc4a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -7,6 +7,7 @@ from contextlib import contextmanager, AbstractContextManager, ContextDecorator from functools import lru_cache +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -20,7 +21,11 @@ from .utils import safely_set_viewless_tensor_data from .constants import dist_group_type from .fp8 import FP8GlobalStateManager -from .float8_tensor import Float8Tensor +from .tensor.float8_tensor import Float8Quantizer, Float8Tensor +from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from .tensor.quantized_tensor import QuantizedTensor, Quantizer +from .tensor._internal.float8_tensor_base import Float8TensorBase +from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -815,7 +820,7 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) # Bypass the function if we are using only 1 GPU. 
@@ -836,57 +841,232 @@ def reduce_scatter_along_first_dim( return output, handle +def _all_gather_fp8( + input_: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: Optional[Float8Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: + """All-gather FP8 tensor along first dimension.""" + world_size = get_distributed_world_size(process_group) + + # Output tensor dims + if out_shape is None: + out_shape = list(input_.size()) + out_shape[0] *= world_size + + # Quantize input tensor if needed + if not isinstance(input_, Float8TensorBase): + assert isinstance(quantizer, Float8Quantizer) + init_columnwise_usage = quantizer.columnwise_usage + quantizer.set_usage(columnwise=False) + input_ = quantizer(input_) + quantizer.set_usage(columnwise=init_columnwise_usage) + + # Construct output tensor + out: Float8TensorBase + if isinstance(quantizer, Float8Quantizer): + dtype = torch.float32 + device = "cuda" + if isinstance(input_, Float8Tensor): + dtype = input_.dtype + device = input_.device + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + elif isinstance(input, Float8Tensor): + out = input_.make_like(input_, shape=out_shape) + out._data = torch.empty_like( + out_shape, + dtype=torch.uint8, + device=input_.device, + ) + out._transpose = None + out._transpose_invalid = True + else: + raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + out._scale_inv = input_._scale_inv + + # Perform communication + handle = torch.distributed.all_gather_into_tensor( + out._data, + input_._data.contiguous(), + group=process_group, + async_op=async_op, + ) + + # Make sure FP8 transpose is populated if needed + if out._transpose is not None: + if handle is not None: + handle.wait() + handle = None + if not isinstance(out, Float8Tensor): + raise RuntimeError("FP8TensorBase does not support FP8 transpose yet") + out._create_transpose() + + return 
out, handle + + +def _all_gather_mxfp8( + input_: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: MXFP8Quantizer, + out_shape: Optional[list[int]] = None, +) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: + """All-gather MXFP8 tensor along first dimension.""" + + # Tensor dims + world_size = get_distributed_world_size(process_group) + in_shape = list(input_.size()) + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # Gather MXFP8 data for row-wise usage + if quantizer.rowwise_usage and not quantizer.columnwise_usage: + + # Cast input tensor to MXFP8 if needed + if not isinstance(input_, MXFP8TensorBase): + input_ = quantizer(input_) + + # Construct MXFP8 output tensor + dtype = torch.float32 + device = "cuda" + if isinstance(input_, MXFP8Tensor): + dtype = input_.dtype + device = input_.device + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Remove padding from MXFP8 scale-inverses + in_scale_inv = input_._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv[flattened_in_shape0 * world_size :].zero_() + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as coalescing_manager: + torch.distributed.all_gather_into_tensor( + out._rowwise_data, + input_._rowwise_data, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + handle = coalescing_manager if async_op else None + return out, handle + + # Gather in high precision and quantize for column-wise usage + if isinstance(input_, QuantizedTensor): + input_ = input_.dequantize(dtype=torch.bfloat16) + out = 
torch.empty( + out_shape, + dtype=input_.dtype, + device=input_.device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, input_, group=process_group) + out = quantizer(out) + return out, None + + def gather_along_first_dim( input_: torch.Tensor, process_group: dist_group_type, async_op: bool = False, -) -> tuple[torch.Tensor, Any]: + quantizer: Optional[Quantizer] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """All-gather tensors and concatenate along first dimension.""" # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: + if quantizer is not None and not isinstance(input_, QuantizedTensor): + input_ = quantizer(input_) return input_, None - # Allocate output tensor - output_shape = list(input_.size()) - output_shape[0] *= world_size - if isinstance(input_, Float8Tensor): - output = Float8Tensor.make_like( + # Output tensor dims + out_shape = list(input_.size()) + out_shape[0] *= world_size + + # FP8 case + if isinstance(input_, Float8TensorBase) or isinstance(quantizer, Float8Quantizer): + return _all_gather_fp8( input_, - data=torch.empty( - output_shape, - dtype=torch.uint8, - device=input_.device, - ), + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, ) - src = input_._data.contiguous() - dst = output._data - else: - output = torch.empty( - output_shape, + + # MXFP8 case + if isinstance(input_, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + assert isinstance(quantizer, MXFP8Quantizer) + return _all_gather_mxfp8( + input_, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + + # High-precision communication for quantized tensors + if quantizer is not None: + warnings.warn( + "Attempting to all-gather an unsupported quantized tensor. " + "Falling back to high-precision all-gather." 
+ ) + if isinstance(input_, QuantizedTensor): + input_ = input_.dequantize() + out = torch.empty( + out_shape, dtype=input_.dtype, device=input_.device, memory_format=torch.contiguous_format, ) - src = input_.contiguous() - dst = output + torch.distributed.all_gather_into_tensor(out, input_, group=process_group) + out = quantizer(out) + return out, None - # Launch all-gather + # Dequantize quantized tensor if not supported + if isinstance(input_, QuantizedTensor): + warnings.warn( + "Attempting to all-gather an unsupported quantized tensor. " + "Falling back to high-precision all-gather." + ) + input_ = input_.dequantize() + + # Communication for plain PyTorch tensors + out = torch.empty( + out_shape, + dtype=input_.dtype, + device=input_.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - dst, - src, + out, + input_.contiguous(), group=process_group, async_op=async_op, ) - return output, handle + return out, handle def allreduce( input_: torch.Tensor, tp_group: Optional[dist_group_type] = None, async_op: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. 
@@ -907,12 +1087,13 @@ def _fsdp_scatter_tensors( if fsdp_group is not None: for t in tensors: if isinstance(t, torch.Tensor): - target = t._data if isinstance(t, Float8Tensor) else t - shapes.append(target.data.shape) - safely_set_viewless_tensor_data( - target, - split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), - ) + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + shapes.append(target.data.shape) + safely_set_viewless_tensor_data( + target, + split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), + ) else: shapes.append(None) return shapes @@ -928,10 +1109,11 @@ def _fsdp_gather_tensors( for s, t in zip(shapes, tensors): if isinstance(t, torch.Tensor): assert s is not None, "Internal TE error." - target = t._data if isinstance(t, Float8Tensor) else t - safely_set_viewless_tensor_data( - target, gather_split_1d_tensor(target.data, fsdp_group).view(s) - ) + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + safely_set_viewless_tensor_data( + target, gather_split_1d_tensor(target.data, fsdp_group).view(s) + ) def _is_te_module(module): diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py deleted file mode 100755 index 79b839edfd..0000000000 --- a/transformer_engine/pytorch/export.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Export utilities for TransformerEngine""" -from contextlib import contextmanager - -_IN_ONNX_EXPORT_MODE = False - - -@contextmanager -def onnx_export( - enabled: bool = False, -) -> None: - """ - Context manager for exporting to ONNX. - - .. 
code-block:: python - - with onnx_export(enabled=True): - torch.onnx.export(model) - - Parameters - ---------- - enabled: bool, default = `False` - whether or not to enable export - """ - - global _IN_ONNX_EXPORT_MODE - onnx_export_state = _IN_ONNX_EXPORT_MODE - try: - _IN_ONNX_EXPORT_MODE = enabled - yield - finally: - _IN_ONNX_EXPORT_MODE = onnx_export_state - - -def is_in_onnx_export_mode() -> bool: - """Returns True if onnx export mode is enabled, False otherwise.""" - return _IN_ONNX_EXPORT_MODE diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8554cc7443..a771e3bb75 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -4,6 +4,6 @@ """Tensor class with FP8 data""" -from .tensor import Float8Tensor +from .tensor.float8_tensor import Float8Tensor __all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b1b6165777..254bcf12e1 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -3,6 +3,9 @@ # See LICENSE for license information. 
"""FP8 utilities for TransformerEngine""" +from __future__ import annotations + +import abc import os from contextlib import contextmanager from collections import deque @@ -10,7 +13,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling from .constants import dist_group_type from .utils import get_device_compute_capability @@ -33,12 +36,21 @@ def check_fp8_support() -> Tuple[bool, str]: return True, "" -def get_default_fp8_recipe() -> DelayedScaling: +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for MXFP8 execution." + + +def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return MXFP8BlockScaling() return DelayedScaling() -def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype: +def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -47,7 +59,7 @@ def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) - return torch.float8_e5m2fn -def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -56,7 +68,7 @@ def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> t return 
tex.DType.kFloat8E5M2 -def get_fp8_max(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -81,7 +93,6 @@ class FP8GlobalStateManager: global_amax_buffer = {} global_amax_history_buffer = {} global_scale_buffer = {} - global_scale_inv_buffer = {} fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" @@ -89,6 +100,8 @@ class FP8GlobalStateManager: autocast_to_fp8_params = {} fp8_param_to_autocast = {} skip_fp8_weight_update_tensor = None + mxfp8_available = None + reason_for_no_mxfp8 = "" @classmethod def reset(cls) -> None: @@ -104,12 +117,15 @@ def reset(cls) -> None: cls.global_amax_buffer = {} cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} - cls.global_scale_inv_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.autocast_arguments = {} + cls.autocast_to_fp8_params = {} + cls.fp8_param_to_autocast = {} cls.skip_fp8_weight_update_tensor = None + cls.mxfp8_available = None + cls.reason_for_no_mxfp8 = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -130,6 +146,13 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 + @classmethod + def is_mxfp8_available(cls) -> Tuple[bool, str]: + """Return if MXFP8/current scaling support is available.""" + if cls.mxfp8_available is None: + cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() + return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -154,7 +177,7 @@ def get_buffer_info(cls) -> str: def get_key_in_buffer( cls, forward: 
bool, - fp8_recipe: DelayedScaling, + fp8_recipe: Recipe, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" @@ -188,6 +211,9 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ + if fp8_meta["recipe"].mxfp8(): + return + # Every module must call this function exactly once since # the amax tensors are static. Ensures that compatibility # with non-graphed modules is maintained. @@ -208,14 +234,12 @@ def add_fp8_tensors_to_global_buffer( cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] else: cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @@ -249,7 +273,7 @@ def is_first_fp8_module(cls): return tmp @classmethod - def get_fp8_recipe(cls) -> DelayedScaling: + def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" if cls.FP8_RECIPE is not None: return cls.FP8_RECIPE @@ -261,7 +285,7 @@ def get_fp8_group(cls) -> Union[dist_group_type, None]: return cls.FP8_DISTRIBUTED_GROUP @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool]: + def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: """FP8 autocast state getter""" return ( cls.FP8_ENABLED, @@ -335,7 +359,6 @@ def reduce_and_update_fp8_tensors( contiguous_amax, cls.global_amax_history_buffer[buffer_key], 
cls.global_scale_buffer[buffer_key], - cls.global_scale_inv_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -343,19 +366,18 @@ def reduce_and_update_fp8_tensors( else: split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - for amax_history, scale, scale_inv in zip( + for amax_history, scale in zip( cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], - cls.global_scale_inv_buffer[buffer_key], ): _amax_and_scale_update( - amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe + amax_history, scale, get_fp8_max(recipe, forward), recipe ) @classmethod def get_unique_autocast_key( cls, - recipe: Optional[DelayedScaling] = None, + recipe: Optional[Recipe] = None, group: Optional[dist_group_type] = None, ): """ @@ -369,7 +391,7 @@ def fp8_autocast_enter( cls, enabled: bool = False, calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: @@ -392,6 +414,9 @@ def fp8_autocast_enter( if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() assert fp8_available, reason_for_no_fp8 + if isinstance(fp8_recipe, MXFP8BlockScaling): + mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() + assert mxfp8_available, reason_for_no_mxfp8 @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -408,12 +433,15 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - """Copy the scaling factors and amaxes for recompute forward phase to ensure both forward steps are numerically same. 
""" + + if fp8_meta["recipe"].mxfp8(): + return + buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" to_copy = [ fp8_meta["scaling_fwd"].amax_history.clone(), fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), ] if buffer_position_key in fp8_meta: @@ -432,10 +460,12 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non 1 forward for indentical numerical outputs. """ + if fp8_meta["recipe"].mxfp8(): + return + # Store updated amaxes and scales from phase 1 post forward. fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale - fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -444,18 +474,20 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non # Replace amaxes and scales with stashed values for phase 2 forward fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) - fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" + + if fp8_meta["recipe"].mxfp8(): + return + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"]) @contextmanager -def fp8_model_init(enabled: bool = True) -> None: +def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None: """ Context manager for FP8 initialization of parameters. 
@@ -479,22 +511,27 @@ def fp8_model_init(enabled: bool = True) -> None: precision copies of weights are already present in the optimizer. * inference, where only the FP8 copies of the parameters are used. * LoRA-like fine-tuning, where the main parameters of the model do not change. + recipe: transformer_engine.common.recipe.Recipe, default = `None` + Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. This functionality is *EXPERIMENTAL*. """ _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE FP8GlobalStateManager.FP8_PARAMETERS = enabled + FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe @contextmanager def fp8_autocast( enabled: bool = True, calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: @@ -529,7 +566,7 @@ def fp8_autocast( data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` + fp8_recipe: recipe.Recipe, default = `None` recipe used for FP8 training. 
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` distributed group over which amaxes for the fp8 tensors @@ -639,7 +676,6 @@ def _compute_scaling_factor( def _amax_and_scale_update( amax_history: torch.Tensor, scale: torch.Tensor, - scale_inv: torch.Tensor, fp8_max: float, recipe: DelayedScaling, ) -> None: @@ -650,7 +686,6 @@ def _amax_and_scale_update( ) new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) scale.copy_(new_scale) - scale_inv.copy_(1.0 / new_scale) amax_history.copy_(new_amax_history) @@ -662,3 +697,152 @@ def split_and_copy( """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" splits = buffer.split(chunk_sizes) torch._foreach_copy_(outputs, splits) + + +class RecipeState(abc.ABC): + """Configuration and state for a quantization recipe. + + This is a builder class for quantizers, which are in turn builder + classes for quantized tensors. + + This class may pack together the state for multiple quantizers, + which is helpful for applying fused kernels with less overhead. + + """ + + @staticmethod + def create( + recipe: Recipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> RecipeState: + """Factory method to create the state for a quantization recipe + + Parameters + ---------- + recipe: Recipe + Quantization recipe. + mode: {"forward", "backward"} + Training stage where quantization will be performed. + num_quantizers: int, default = 1 + Number of quantizers to create state for. + device: torch.device, default = default CUDA device + Device for quantized tensors. + + Returns + ------- + RecipeState: + Quantization recipe state. 
+ + """ + + cls = None + if recipe.delayed(): + cls = DelayedScalingRecipeState + elif recipe.mxfp8(): + cls = MXFP8BlockScalingRecipeState + else: + raise ValueError("{recipe.__class__.__name__} is not supported") + return cls( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + + @abc.abstractmethod + def make_quantizers(self) -> list: + """Convert recipe state to quantizers. + + Quantizers are builder classes for quantized tensors. They are + typically used to convert a high-precision tensor (e.g. in + FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + +class DelayedScalingRecipeState(RecipeState): + """State for FP8 quantization with per-tensor delayed scaling. + + Delayed scaling recipe requires a scaling factor (applied when + casting to FP8) and a history of max-abs values ("amax") from + recent FP8 casts for updating the scaling factor. The scale update + is handled externally by `FP8GlobalStateManager`. + + """ + + recipe: DelayedScaling + mode: str + dtype: tex.DType + scale: torch.Tensor + amax_history: torch.Tensor + + def __init__( + self, + recipe: DelayedScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) + self.amax_history = torch.zeros( + recipe.amax_history_len, + num_quantizers, + dtype=torch.float32, + device=device, + ) + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. 
+ from .tensor.float8_tensor import Float8Quantizer + + return [ + Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) + for i in range(self.num_quantizers) + ] + + +class MXFP8BlockScalingRecipeState(RecipeState): + """Configuration for MXFP8 quantization. + + MXFP8 quantization does not require state. + + """ + + recipe: MXFP8BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: MXFP8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.mxfp8_tensor import MXFP8Quantizer + + return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 3853e70048..83b316aad4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -11,7 +11,7 @@ from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.common.recipe import DelayedScaling, Recipe from transformer_engine.pytorch.constants import dist_group_type from .fp8 import ( fp8_autocast, @@ -556,12 +556,16 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: DelayedScaling, -) -> List[Any]: + fp8_recipe: Recipe, +) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. 
""" + + if not isinstance(fp8_recipe, DelayedScaling): + return None + fp8_tensors = [] for module in modules: for m in module.modules(): @@ -579,9 +583,13 @@ def save_fp8_tensors( def restore_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_tensors: List[Any], + fp8_tensors: Optional[List[Any]], ) -> None: """Restore FP8 tensors.""" + + if fp8_tensors is None: + return + for module in modules: for m in module.modules(): module_tensors = fp8_tensors.pop(0) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 2be291e4f9..cd18808465 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,29 +4,27 @@ """Internal function used by multiple modules.""" -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +import os +from typing import Any, List, Optional, Tuple, Union, Callable from dataclasses import dataclass +from functools import reduce +from operator import mul as multiply_op import torch from .. 
import cpp_extensions as tex -from ..export import is_in_onnx_export_mode -from ..fp8 import get_fp8_te_dtype +from ..constants import TE_DType from ..utils import get_default_init_method +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.mxfp8_tensor import MXFP8Quantizer +_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) -def _get_normalization_func( - normalization: str, fp8_output: bool, is_grad_enabled: bool, forward: bool -): + +def _get_normalization_func(normalization: str, forward: bool): fwd_normalization_funcs = { - ("LayerNorm", True, True): tex.layernorm_fwd_fp8, - ("LayerNorm", True, False): tex.layernorm_fwd_fp8_inf, - ("LayerNorm", False, True): tex.layernorm_fwd_noalloc, - ("LayerNorm", False, False): tex.layernorm_fwd_inf, - ("RMSNorm", True, True): tex.rmsnorm_fwd_fp8, - ("RMSNorm", True, False): tex.rmsnorm_fwd_fp8_inf, - ("RMSNorm", False, True): tex.rmsnorm_fwd_noalloc, - ("RMSNorm", False, False): tex.rmsnorm_fwd_inf, + "LayerNorm": tex.layernorm_fwd, + "RMSNorm": tex.rmsnorm_fwd, } bwd_normalization_funcs = { "LayerNorm": tex.layernorm_bwd, @@ -34,81 +32,79 @@ def _get_normalization_func( } if forward: - return fwd_normalization_funcs[(normalization, fp8_output, is_grad_enabled)] - assert not fp8_output, "FP8 output is not supported in backward normalization!" - assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!" + return fwd_normalization_funcs[normalization] return bwd_normalization_funcs[normalization] -def _apply_normalization( +def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor: + """Reorder FP8 transposes after Userbuffers gather. + + The all-gather is performed in-place in the Float8Tensor's + row-wise data, and afterwards we need to do a transpose to get the + correct ordering. This misuses data fields in Float8Tensor and + should be considered an evil hack. 
It would be best to move + transpose logic into CommOverlap::get_buffer. + + Responsibility for fixing: adener, tmoon + + """ + assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor" + assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1" + assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data" + assert ( + fp8_tensor._data.shape[0] % tp_size == 0 + ), "Leading dimension of data is not divisble by TP size" + + data = fp8_tensor._data + batched_size = reduce(multiply_op, data.shape[1:]) + interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size] + transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size] + fp8_tensor._transpose = ( + data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape) + ) + + fp8_tensor._transpose_invalid = False + fp8_tensor._data = None + + return fp8_tensor + + +def apply_normalization( inputmat: torch.Tensor, ln_out: torch.Tensor, ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], eps: float, - fp8_out: bool, - fp8_meta: Dict[str, Any], + output_quantizer, + output_dtype, normalization: str, fwd_ln_sm_margin: int, zero_centered_gamma: bool, - is_grad_enabled: bool, - fp8_scale: Optional[torch.Tensor] = None, - fp8_amax: Optional[torch.Tensor] = None, - fp8_scale_inv: Optional[torch.Tensor] = None, ): - normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) + """Apply normalization to input.""" + normalization_func = _get_normalization_func(normalization, True) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) - if fp8_out: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - - if is_grad_enabled: - output_key = "ln_out" if normalization == "LayerNorm" else "rmsnorm_out" - output_kwarg = {output_key: ln_out} - output = normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - 
tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - scale=fp8_scale, - amax=fp8_amax, - scale_inv=fp8_scale_inv, - **output_kwarg, - ) - else: - return ( - normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - scale=fp8_scale, - amax=fp8_amax, - scale_inv=fp8_scale_inv, - ), - None, - None, - ) - else: - if is_grad_enabled: - output = normalization_func(*inputs, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma) - else: - return ( - normalization_func(*inputs, eps, fwd_ln_sm_margin, zero_centered_gamma), - None, - None, - ) - if normalization == "RMSNorm": - output = (ln_out, None, output[1]) - elif normalization == "LayerNorm": - output = (ln_out, output[1], output[2]) - return output + + split_mxfp8_cast = False + if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer): + split_mxfp8_cast = True + + output = normalization_func( + *inputs, + eps, + None if split_mxfp8_cast else ln_out, + None if split_mxfp8_cast else output_quantizer, + TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, + fwd_ln_sm_margin, + zero_centered_gamma, + ) + + return ( + (output_quantizer.quantize(output[0], out=ln_out), *output[1:]) + if split_mxfp8_cast + else output + ) class _NoopCatFunc(torch.autograd.Function): @@ -202,7 +198,7 @@ def backward( return None, *grad_inputs -def _noop_cat( +def noop_cat( tensors: List[torch.Tensor], dim: int = 0, ) -> torch.Tensor: @@ -217,8 +213,6 @@ def _noop_cat( raise ValueError("Attempted to concatenate 0 tensors") if len(tensors) == 1: return tensors[0] - if is_in_onnx_export_mode(): - return torch.cat(tensors, dim=dim) return _NoopCatFunc.apply(dim, *tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8de0b733a9..d0f9525135 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -18,12 +18,14 @@ import torch.nn.functional as F import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe + from ._common import _ParameterInitMeta -from ..export import is_in_onnx_export_mode from ..fp8 import ( - get_default_fp8_recipe, - get_fp8_te_dtype, + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, FP8GlobalStateManager, + RecipeState, ) from ..distributed import ( gather_along_first_dim, @@ -31,13 +33,10 @@ in_fp8_activation_recompute_phase, _fsdp_gather_tensors, ) -from ..cpp_extensions import ( - fp8_cast_transpose_fused, - fp8_cast_transpose_bgrad_fused, - cast_to_fp8, -) from ..constants import dist_group_type -from ..float8_tensor import Float8Tensor +from ..tensor import QuantizedTensor, Quantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -48,6 +47,7 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 +_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] @@ -295,34 +295,43 @@ def get_method(name): raise KeyError(f"Given layer name {name} does not exist.") def get_default_config(name): + global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY method = get_method(name) is_reduce_scatter = name in layers_reduce_scatter_overlap + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() default_cfg = { "method": method, "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": False, - "num_splits": 4 if method == "pipeline" else tp_size, + "set_sm_margin": not method == "ring_exchange", + "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, 
"use_ce": True, "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + "pipeline_rs_overlap_first_gemm": False, } return default_cfg def add_ub( name: str, method: str, - is_reduce_scatter: int, + is_reduce_scatter: bool, num_sm: int = 16, cga_size: int = 2, - set_sm_margin: int = 0, + set_sm_margin: bool = False, num_splits: int = 0, - aggregate: int = 0, - atomic_gemm: int = 0, + aggregate: bool = False, + atomic_gemm: bool = False, use_ce: bool = True, fp8_buf: bool = False, + comm_priority: int = 0, + gemm_priority: int = 0, + pipeline_rs_overlap_first_gemm: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -373,6 +382,8 @@ def add_ub( atomic_gemm=atomic_gemm, use_ce=use_ce, aggregate=aggregate, + gemm_priority=gemm_priority, + comm_priority=comm_priority, ) else: ub_obj = tex.CommOverlap( @@ -386,6 +397,9 @@ def add_ub( num_comm_sm=num_sm, set_sm_margin=set_sm_margin, atomic_gemm=atomic_gemm, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) _ub_communicators[name] = ub_obj @@ -439,8 +453,8 @@ def __init__(self) -> None: self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta_tensors_initialized = False + self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} self.tp_group = None self.tp_size = 1 self.sequence_parallel = False @@ -448,7 +462,7 @@ def __init__(self) -> None: self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.fsdp_wrapped = False self.fsdp_group = None - self._fp8_workspaces: Dict[str, Float8Tensor] = {} + self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None # Names of attributes that can be set quickly (see __setattr__ @@ -499,6 +513,9 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = 
None) -> self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) ) + # Update quantizers with new amax pointers. + self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() + # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[ @@ -516,37 +533,38 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[meta_key].amax_history ) - def set_meta_tensor(self, fwd: bool) -> None: + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + # Return early if recipe state matches recipe if self.fp8_meta_tensors_initialized: - # Handle changed amax history size. - self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) - return + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) + return + if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 - self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() - self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros( - self.fp8_meta["recipe"].amax_history_len, - num_fp8_tensors, - dtype=torch.float32, - device="cuda", + # Initialize recipe state and quantizers + recipe_state = RecipeState.create( + recipe, + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors, ) - def init_fp8_meta_tensors(self) -> None: + self.fp8_meta[fp8_meta_tensor_key] = recipe_state + self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + + def init_fp8_meta_tensors(self, recipe: Recipe) -> None: """Init scales and amaxes.""" - self.set_meta_tensor(True) - self.set_meta_tensor(False) + self.set_meta_tensor(True, recipe) + self.set_meta_tensor(False, recipe) + self.fp8_meta_tensors_initialized = True def get_fp8_meta_tensors(self) -> None: @@ -559,7 +577,6 @@ def get_fp8_meta_tensors(self) -> None: with torch.no_grad(): for key in (fwd_key, bwd_key): fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) - fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) return fp8_meta_tensors @@ -570,17 +587,13 @@ def reset(key): if key in self.fp8_meta: if fp8_meta_tensors is None: self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) - self.fp8_meta[key].scale_inv.copy_( - torch.ones_like(self.fp8_meta[key].scale_inv) - ) self.fp8_meta[key].amax_history.copy_( torch.zeros_like(self.fp8_meta[key].amax_history) ) 
else: assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) - self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) - self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) with torch.no_grad(): reset("scaling_fwd") @@ -624,12 +637,12 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Copy tensors to CPU and store state = {} - state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) - state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) - state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv) - state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) - state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) - state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv) + state["recipe"] = self.fp8_meta["recipe"] + if state["recipe"].delayed(): + state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) + state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) + state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) + state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) # Store other pickelable values extra = {} @@ -667,12 +680,12 @@ def set_extra_state(self, state: torch.Tensor) -> None: # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] + self.fp8_meta["recipe"] = state["recipe"] if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: """Helper function to copy tensor from CPU @@ -684,12 +697,11 @@ def copy_tensor(src: torch.Tensor, dst: 
torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load tensors - copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) - copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) - copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv) - copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) - copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) - copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv) + if self.fp8_meta["recipe"].delayed(): + copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) + copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) + copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) + copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) torch.cuda.synchronize() def set_activation_dtype(self, inp: torch.Tensor) -> None: @@ -729,7 +741,7 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" fp8_params = [] for param in self.parameters(recurse=False): - if isinstance(param, Float8Tensor) and param.requires_grad: + if isinstance(param, QuantizedTensor) and param.requires_grad: fp8_params.append(param) if len(fp8_params) == 0: return None @@ -742,22 +754,28 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors() - - if self.fp8 or self.fp8_calibration: - # FP8 init has already been run and recipe is the same, don't do anything. 
+ if self.fp8_parameters or fp8_enabled: if ( self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] ): + # FP8 init has already been run and recipe is the same, don't do anything. return - - # Set FP8, recipe, and other FP8 metadata self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + else: + # If fp8 isn't enabled, turn off and return. + self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + + if fp8_enabled: + # Set FP8 and other FP8 metadata self.fp8_meta["num_gemms"] = num_gemms self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() @@ -766,17 +784,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) self.fp8_initialized = True - else: - # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + + self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() @contextmanager def prepare_forward( self, inp: torch.Tensor, - is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument num_gemms: int = 1, allow_non_contiguous: bool = False, ) -> Generator[torch.Tensor, None, None]: @@ -798,7 +814,7 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) - if self.fp8 and self.sequence_parallel: + if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( "Amax reduction across tensor parallel group is " "necessary when using sequence parallelism with FP8." 
@@ -838,110 +854,64 @@ def set_nccl_overlap_warning_if_tp(self) -> None: @staticmethod def grad_output_preprocess( - ctx, grad_output: torch.Tensor, row_parallel_mode: bool + ctx, + grad_output: torch.Tensor, + row_parallel_mode: bool, + quantizer: Optional[Quantizer], ) -> Tuple[Union[torch.Tensor, None], ...]: """Utility function for backward. Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output` in higher precision. - R2: gathered `grad_output` in FP8. - R3: R2 transposed. - R4: bias gradient on R1. + R1: gathered `grad_output`. + R2: bias gradient on R1. """ - if isinstance(grad_output, Float8Tensor): - grad_output._data = grad_output._data.contiguous() - else: - grad_output = grad_output.contiguous() - grad_output_mat = grad_output.view(-1, grad_output.shape[-1]) + grad_output = grad_output.reshape((-1, grad_output.shape[-1])) + grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - # No-FP8 case: bgrad is fused with wgrad for this case. + # Non-FP8 case: bgrad is fused with wgrad for this case. 
if not ctx.fp8: if gather_grad_output: if not ctx.ub_overlap_ag: - grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) + grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) else: - ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) - grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1) - return grad_output_mat, None, None, None - - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - # FP8 case with non-FP8 wgrad - if gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - assert ( - not ctx.ub_overlap_ag - ), "override_linear_precision.wgrad not supported with UB AG overlap" - grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) - # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather - elif gather_grad_output: + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer) + return grad_output, None + + # FP8 with all-gather: unfused bgrad, fused cast + transpose + if gather_grad_output: + grad_bias = None if ctx.use_bias: - grad_bias = grad_output_mat.sum(dim=0) - else: - grad_bias = None + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) if ctx.ub_overlap_ag: - grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) + # Quantize the gradient if needed + if not isinstance( + grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + ): + grad_output = quantizer(grad_output) + + # Copy into communication buffer, and replace original gradient with it + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer) else: - grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) - if not isinstance(grad_output_mat, Float8Tensor): - cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - 
fp8_dtype_backward, - out=grad_output_c, + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=quantizer, ) - else: - grad_output_c = grad_output_mat - if not ctx.ub_overlap_ag: - grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) - if not isinstance(grad_output_c, Float8Tensor): - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - else: - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) - grad_output_t = None + return grad_output, grad_bias - return grad_output_mat, grad_output_c, grad_output_t, grad_bias - - # FP8 case without gather: cast, transpose, bgrad fused + # FP8 without all-gather: fused bgrad + cast + transpose + grad_bias = None if ctx.use_bias: - grad_output_mat_no_fp8 = grad_output_mat - if isinstance(grad_output_mat, Float8Tensor): - grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype) - grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused( - grad_output_mat_no_fp8, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if isinstance(grad_output_mat, Float8Tensor): - grad_output_c = grad_output_mat - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_c, grad_output_t = fp8_cast_transpose_fused( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) + if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_output_t = None - if not isinstance(grad_output_mat, Float8Tensor): - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - grad_output_c = grad_output_mat - grad_bias = None - - return 
grad_output_mat, grad_output_c, grad_output_t, grad_bias + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_output = quantizer(grad_output) + return grad_output, grad_bias def register_parameter(self, name, param, **kwargs): """ @@ -978,21 +948,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as Float8Tensor + # If primary weights are in fp8, wrap the parameter as FP8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: - dummy_amax = torch.empty( - (1, 1), - dtype=torch.float32, - device=param.device, - ) # Dummy buffer to avoid overwriting amax history - param = Float8Tensor.to_float8( - param, - fp8_meta=self.fp8_meta, - fp8_meta_index=fp8_meta_index, - amax=dummy_amax, - with_transpose_cache=torch.is_grad_enabled(), - ) + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + assert ( + quantizer is not None + ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. 
+ quantizer.internal = False + param = quantizer(param) # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but @@ -1004,17 +968,16 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: def forward(self): """Needs override.""" - def get_fp8_workspace( + def get_weight_workspace( self, *, tensor: Optional[torch.Tensor] = None, - fp8_meta_forward: Optional[bool] = None, - fp8_meta_index: Optional[int] = None, + quantizer: Optional[Quantizer] = None, cache_name: Optional[str] = None, update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: dist_group_type = None, - ) -> Float8Tensor: + fsdp_group: Optional[dist_group_type] = None, + ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values The workspace buffer may be cached for future function calls. @@ -1024,13 +987,9 @@ def get_fp8_workspace( tensor : torch.Tensor, optional Values to copy into workspace. Required if the workspace is being constructed or updated. - fp8_meta_forward: bool, optional - Whether to access FP8 meta tensors for the forward pass or - backward pass. Required if the workspace is being - constructed. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if the - workspace is being constructed. + quantizer: Quantizer, optional + Quantizer used to cast the weights. Required if the + workspace is being constructed or updated. cache_name: str, optional Key for caching. update_workspace: bool, default = `True` @@ -1052,61 +1011,24 @@ def get_fp8_workspace( # for models initialized with Fp8 primary weights. 
if ( out is not None - and not isinstance(out, Float8Tensor) + and tensor is not None and fsdp_group is not None - and out._data.shape != tensor.data.shape + and out.data.shape != tensor.data.shape ): _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) # Construct workspace if needed if out is None: - - # FP8 data - if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: + if tensor is None or quantizer is None: raise ValueError( - "tensor, fp8_meta_forward, and fp8_meta_index kwargs " - "must be provided to construct FP8 workspace" - ) - fp8_dtype = get_fp8_te_dtype( - self.fp8_meta["recipe"], - fprop_tensor=fp8_meta_forward, - ) - data = torch.empty_like(tensor, dtype=torch.uint8) - scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) - - # Transpose cache - with_transpose_cache = torch.is_grad_enabled() - if ( - not with_transpose_cache - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose_cache = True - data_transpose = None - if with_transpose_cache: - data_transpose = torch.empty( - (tensor.size(-1), tensor.numel() // tensor.size(-1)), - dtype=torch.uint8, - device=tensor.device, + "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - - # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_meta=self.fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, - dtype=tensor.dtype, - data_transpose=data_transpose, - ) + out = quantizer(tensor) # Update cache if cache_name is not None: self._fp8_workspaces[cache_name] = out - update_workspace = True - skip_update_flag = None + return out # Update workspace if needed if skip_update_flag is not None: @@ -1114,17 +1036,10 @@ def get_fp8_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if is_in_onnx_export_mode(): - # ONNX export does not 
support fused cast-transpose - # kernel and requires that FP8 scales can be - # represented with constant ops. - transpose_cache = out._transpose - out._transpose = None - out.quantize_(tensor) - out._scale_inv.fill_(out._scale_inv.item()) - out._transpose = transpose_cache - else: + if hasattr(out, "quantize_"): out.quantize_(tensor, noop_flag=skip_update_flag) + else: + tex.quantize(tensor, quantizer, out, skip_update_flag) return out diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 1034398875..2549d45728 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -8,9 +8,8 @@ import torch -from ..cpp_extensions import ( - multi_padding_fused, -) +import transformer_engine_torch as tex + from ..jit import no_torch_dynamo @@ -36,7 +35,7 @@ def forward( total_row = sum(padded_m_splits) out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) - multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out) + tex.fused_multi_row_padding(inp.view(-1, in_features), out, m_splits, padded_m_splits) if is_grad_enabled: ctx.m_splits = m_splits diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index b0832b0848..479b91d396 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -8,9 +8,8 @@ import torch -from ..cpp_extensions import ( - multi_padding_fused, -) +import transformer_engine_torch as tex + from ..jit import no_torch_dynamo @@ -56,8 +55,8 @@ def backward(ctx, grad_output: torch.Tensor): [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device ) # FP8 pad input for forward, FP8 input transpose for backward wgrad - multi_padding_fused( - grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input + tex.fused_multi_row_padding( 
+ grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits ) return (grad_input, None, None, None) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 65023e493b..2f9de58984 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """GroupedLinear API""" -from typing import Union, Optional, Callable, Tuple, List, Dict, Any +from typing import Union, Optional, Callable, Tuple, List import torch @@ -16,7 +16,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, @@ -28,21 +28,26 @@ from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, ) from ..cpp_extensions import ( - cast_to_fp8, - fp8_cast_transpose_bgrad_fused, - fp8_multi_cast_transpose_fused, - fp8_grouped_gemm, - grouped_gemm, + general_grouped_gemm, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor import Float8Tensor, QuantizedTensor -from ..export import is_in_onnx_export_mode +from ..tensor.float8_tensor import Float8Tensor from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + + __all__ = ["GroupedLinear"] @@ -60,202 +65,141 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], + input_quantizers: List[Quantizer], + weight_quantizers: List[Quantizer], + output_quantizers: List[Quantizer], + 
grad_output_quantizers: List[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, sequence_parallel: bool, activation_dtype: torch.dtype, - fp8_meta_offsets: Dict[str, int], is_grad_enabled: bool, - weights_fp8: List[Union[Float8Tensor, None]], - *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], + module, + skip_fp8_weight_update, + *weights_and_biases, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] + device = inp.device + + # TODO Support MXFP8 # pylint: disable=fixme + if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): + raise NotImplementedError("GroupedLinear does not yet support MXFP8") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" inputmats = torch.split(inp.view(-1, in_features), m_splits) if fp8: - for i in range(num_gemms): - assert_dim_for_fp8_exec(inputmats[i]) - assert_dim_for_fp8_exec(weights[i]) + assert_dim_for_fp8_exec(*inputmats, *weights) # Cast input to expected dtype inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] - inputmats_t = [] - inputmat_scale_inv = None - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weights[0].requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - indices = list( - range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms) + weight_requires_grad = weights[0].requires_grad + + if input_quantizers[0] is not None: + for input_quantizer in input_quantizers: + input_quantizer.set_usage( + rowwise=True, + columnwise=(is_grad_enabled and 
weight_requires_grad), ) - inputmats, inputmats_t = fp8_multi_cast_transpose_fused( - inputmats_no_fp8, - fp8_meta["scaling_fwd"], - indices, # scale_indices - indices, # amax_indices - indices, # scale_inv_indices - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() ) - else: - # FP8 input for forward - inputmats = [ - cast_to_fp8( - inputmats_no_fp8[i], - fp8_meta["scaling_fwd"], - fp8_meta_offsets["input"] + i, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + if weight_quantizers[0] is not None: + for weight_quantizer in weight_quantizers: + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + if output_quantizers[0] is not None: + for output_quantizer in output_quantizers: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8: + inputmats = tex.fused_multi_quantize( + inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] + ) + weights_fp8 = [] + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + if not isinstance(weights[0], QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + for i in range(num_gemms): + weight_fp8 = module.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, ) - for i in range(num_gemms) - ] + weights_fp8.append(weight_fp8) + else: + weights_fp8 = weights - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. 
- # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 + bias_dtype = activation_dtype + weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases + biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases - # Use FP8 weights - if weights_fp8[0] is None: - weights_fp8 = weights - assert all(isinstance(w, Float8Tensor) for w in weights_fp8) - - out = torch.empty( - [sum(m_splits), weights_fp8[0].size(0)], - dtype=activation_dtype, - device=inputmats[0].device, - ) + out = torch.empty( + [sum(m_splits), weights_fp8[0].size(0)], + dtype=activation_dtype, + device=device, + ) - _ = fp8_grouped_gemm( - [w._data for w in weights_fp8], - [w._scale_inv for w in weights_fp8], - 0, # weight offset is 0 for the newly created _scale_inv - fp8_dtype_forward, - inputmats, - inputmat_scale_inv, - 0, - fp8_dtype_forward, - [out], - activation_dtype, - get_multi_stream_cublas_workspace(), - m_splits=m_splits, - bias=biases, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ) - else: - # Cast for native AMP - weights = [cast_if_needed(w, activation_dtype) for w in weights] - biases = ( - [cast_if_needed(bias, activation_dtype) for bias in biases] if use_bias else biases - ) + _ = general_grouped_gemm( + weights_fp8, + inputmats, + [out], + activation_dtype, + get_multi_stream_cublas_workspace(), + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=_2X_ACC_FPROP, + ) - if fp8_calibration: + if fp8_calibration: 
+ for i in range(num_gemms): + # amax of input for i in range(num_gemms): - # amax of input - amin, amax = inputmats[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = ( - torch.max(-amin, amax).float() - ) - # amax of weight - amin, amax = weights[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = ( - torch.max(-amin, amax).float() - ) + input_quantizers[i].calibrate(inputmats[i]) + for i in range(num_gemms): + weight_quantizers[i].calibrate(weights[i]) - out = torch.empty( - [sum(m_splits), weights[0].size(0)], - dtype=activation_dtype, - device=inputmats[0].device, - ) + if is_grad_enabled: - _ = grouped_gemm( - weights, - inputmats, - torch.split(out, m_splits), - activation_dtype, - get_multi_stream_cublas_workspace(), - bias=biases, - use_bias=use_bias, - ) + saved_inputs, saved_weights = [], [] + ctx.weights_shape_1 = weights[0].shape[1] - if is_grad_enabled: - saved_inputmats = [None] * num_gemms - saved_inputmats_t = [None] * num_gemms - if weights[0].requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if not inputmats_t: - saved_inputmats = inputmats - else: - saved_inputmats_t = inputmats_t - if cpu_offloading: - for t in saved_inputmats_t: - t.activation_offloading = True - else: - saved_inputmats = inputmats_no_fp8 - - if cpu_offloading: - if fp8: - for w in weights_fp8: - if w is not None: - w.weight_offloading = True - for w in weights: - w.weight_offloading = True - for t in saved_inputmats: - if t is not None: - t.activation_offloading = True - - ctx.save_for_backward( - inputmat_scale_inv, - *saved_inputmats, - *saved_inputmats_t, - *weights, - *weights_fp8, - *[ - w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None - for w in weights - ], - ) + tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + 
ctx.weights_requires_grad = weights[0].requires_grad + ctx.device = device + ctx.saved_inputs = saved_inputs + ctx.saved_weights = saved_weights + ctx.grad_output_quantizers = grad_output_quantizers ctx.m_splits = m_splits ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.inp_shape = inp.shape - ctx.fp8_meta_offsets = fp8_meta_offsets ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): @@ -271,66 +215,42 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_GroupedLinear_backward"): - ( - inputmat_scale_inv, - *saved_tensors, - ) = ctx.saved_tensors - inputmats = saved_tensors[: ctx.num_gemms] - inputmats_t = saved_tensors[ctx.num_gemms : 2 * ctx.num_gemms] - weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms] - weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms] - main_grads = saved_tensors[4 * ctx.num_gemms :] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + N = ctx.num_gemms + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] + main_grads = saved_tensors[3 * N :] + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO for i in ctx.num_gemms: w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w # preprocess grad_output + grad_output = grad_output.contiguous() grad_output_mats = torch.split( grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits ) - grad_output_c = 
[None] * ctx.num_gemms - grad_output_t = [None] * ctx.num_gemms + grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) if ctx.use_bias: for i in range(ctx.num_gemms): - grad_biases[i], grad_output_c[i], grad_output_t[i] = ( - fp8_cast_transpose_bgrad_fused( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - ctx.fp8_meta_offsets["grad_output"] + i, - fp8_dtype_backward, - ) + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], ctx.grad_output_quantizers[i] ) else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - indices = list( - range( - ctx.fp8_meta_offsets["grad_output"], - ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms, - ) - ) - grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( - grad_output_mats, - ctx.fp8_meta["scaling_bwd"], - indices, # scale_indices - indices, # amax_indices - indices, # scale_inv_indices - fp8_dtype_backward, - ) - else: - for i in range(ctx.num_gemms): - grad_output_c[i] = cast_to_fp8( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - ctx.fp8_meta_offsets["grad_output"] + i, - fp8_dtype_backward, - ) + grad_output = tex.fused_multi_quantize( + grad_output_mats, + None, + ctx.grad_output_quantizers, + TE_DType[ctx.activation_dtype], + ) + else: + grad_output = grad_output_mats if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -340,111 +260,57 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.requires_dgrad: - if ctx.fp8: - dgrad = torch.empty( - (sum(ctx.m_splits), weights_fp8[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - fp8_grouped_gemm( - [w.transpose_2d() for w in weights_fp8], - [w._scale_inv for 
w in weights_fp8], - 0, # weight offset is 0 for the newly created _scale_inv - weights_fp8[0]._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - ctx.fp8_meta_offsets["grad_output"], - fp8_dtype_backward, - [dgrad], - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - m_splits=ctx.m_splits, - use_split_accumulator=_2X_ACC_DGRAD, - ) - else: - dgrad = torch.empty( - (sum(ctx.m_splits), weights[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - grouped_gemm( - weights, - grad_output_mats, - torch.split(dgrad, ctx.m_splits), - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NN", - grad=True, - ) + dgrad = torch.empty( + (sum(ctx.m_splits), ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) - if weights[0].requires_grad: + general_grouped_gemm( + weights, + grad_output, + torch.split(dgrad, ctx.m_splits), + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=_2X_ACC_DGRAD, + ) + + if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation: wgrad_list = [w.main_grad for w in weights] else: wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) + torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - if ctx.fp8: - # WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmats_t[0] is None: - for i in range(ctx.num_gemms): - if isinstance(inputmats[i], Float8Tensor): - inputmats_t[i] = inputmats[i].transpose_2d() - else: - inputmats_t[i] = tex.fp8_transpose( - inputmats[i], fp8_dtype_backward - ) - fp8_grouped_gemm( - [ - inp._data if isinstance(inp, Float8Tensor) else inp - for inp in inputmats_t - ], - [inputmat_scale_inv], - 0, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - ctx.fp8_meta_offsets["grad_output"], - fp8_dtype_backward, - wgrad_list, - 
ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - use_split_accumulator=_2X_ACC_WGRAD, - ) - else: - grouped_gemm( - inputmats, - grad_output_mats, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - ) - else: # WGRAD - _, grad_biases, _ = grouped_gemm( + _, grad_biases_, _ = general_grouped_gemm( inputmats, - grad_output_mats, + grad_output, wgrad_list, ctx.activation_dtype, get_multi_stream_cublas_workspace(), layout="NT", grad=True, - use_bias=ctx.use_bias, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=_2X_ACC_WGRAD, accumulate=accumulate_wgrad_into_param_main_grad, ) + for i in range(ctx.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ # Deallocate input tensor clear_tensor_data(*inputmats) - clear_tensor_data(*inputmats_t) def handle_custom_ddp_from_mcore(w, wgrad): - if w.requires_grad: + if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): w.grad_added_to_main_grad = True if getattr(w, "zero_out_wgrad", False): @@ -478,22 +344,24 @@ def handle_custom_ddp_from_mcore(w, wgrad): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, - None, # m_splits - None, # use_bias - None, # is_first_microbatch - None, # fp8 - None, # fp8_calibration - None, # fp8_meta - None, # fuse_wgrad_accumulation - None, # cpu_offloading - None, # sequence_parallel - None, # activation_dtype - None, # fp8_meta_offsets + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, # is_grad_enabled None, # is_grad_enabled - None, # 
weights_fp8 *wgrad_list, *grad_biases, ) @@ -718,7 +586,7 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp: + with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -727,29 +595,32 @@ def forward( w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors ] - weight_tensors_fp8 = [None] * self.num_gemms + input_quantizers, weight_quantizers, output_quantizers = ( + [None] * self.num_gemms, + [None] * self.num_gemms, + [None] * self.num_gemms, + ) + grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms if self.fp8: + input_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] for i in range(self.num_gemms): - if isinstance(weight_tensors[i], Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. 
- if weight_tensors[i]._transpose is not None: - weight_tensors[i].transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_tensors_fp8[i] = self.get_fp8_workspace( - tensor=weight_tensors[i], - fp8_meta_forward=True, - fp8_meta_index=self._offsets["weight"] + i, - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) + input_quantizers[i].internal = True + weight_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["weight"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + weight_quantizers[i].internal = True + if torch.is_grad_enabled(): + grad_output_quantizers = [ + self.quantizers["scaling_bwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + grad_output_quantizers[i].internal = True if torch.is_grad_enabled(): linear_fn = _GroupedLinear.apply @@ -764,14 +635,17 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, + input_quantizers, + weight_quantizers, + output_quantizers, + grad_output_quantizers, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.sequence_parallel, self.activation_dtype, - self._offsets, torch.is_grad_enabled(), - weight_tensors_fp8, + self, + skip_fp8_weight_update, *weight_tensors, *bias_tensors, ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 189464cf80..60c73a8d7d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -5,12 +5,14 @@ """LayerNormLinear API""" import os import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union +from 
functools import reduce +from operator import mul as multiply_op import torch from torch.nn import init -from .. import cpp_extensions as tex +import transformer_engine_torch as tex from .base import ( get_workspace, @@ -20,7 +22,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -40,14 +42,22 @@ _fsdp_scatter_tensors, _fsdp_gather_tensors, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import _apply_normalization, _noop_cat -from ..float8_tensor import Float8Tensor -from ..export import is_in_onnx_export_mode -from ..tensor import QuantizedTensor -from ..cpu_offload import is_cpu_offload_enabled +from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + +from ..cpp_extensions import ( + general_gemm, +) __all__ = ["LayerNormLinear"] @@ -64,15 +74,18 @@ def forward( ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], weight: torch.Tensor, - weight_fp8: Optional[torch.Tensor], bias: torch.Tensor, use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -87,13 +100,16 @@ def forward( 
bwd_ln_sm_margin: int, zero_centered_gamma: bool, normalization: str, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_ag: bool, ub_name: str, - fp8_output: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible @@ -102,8 +118,7 @@ def forward( assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) + assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -111,205 +126,183 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_overlap_ag: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1 or (not is_grad_enabled): - ub_overlap_ag = False - if ub_overlap_ag: - dim_size = list(inputmat.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ub_name + "_fprop") - if return_layernorm_output: - # First prepare LN output in higher precision, - # which will be later copied to a FP8 UB - ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format) + tp_world_size = get_distributed_world_size(tp_group) + ub_overlap_ag_fprop = ( + ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output + ) + + weight_requires_grad = weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad + with_input_all_gather = parallel_mode == "column" and sequence_parallel + + if fp8: + if ( + any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + 
"Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + + # Configure quantizer for normalization output + with_quantized_norm = fp8 and not return_layernorm_output + if with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(input_quantizer, MXFP8Quantizer): + with_quantized_norm = False else: - ln_out = ub_obj_lnout.get_ubuf_output(0) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, + ) + + ub_obj_fprop = None + ln_out = None + if ub_overlap_ag_fprop: + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) + elif with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda") else: - ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" ) - # Objects for FP8 cast - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - ln_out_scale_inv = None - if fp8: - ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) - - # Launch normalization kernel - ln_out, mu, rsigma = _apply_normalization( + # Apply normalization + ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, ln_weight, ln_bias, eps, - fp8 and not return_layernorm_output, - fp8_meta, + input_quantizer if with_quantized_norm else None, + inp.dtype, normalization, fwd_ln_sm_margin, zero_centered_gamma, - is_grad_enabled, - fp8_scale_inv=ln_out_scale_inv, ) - - # Column Parallel Linear - ln_out_gathered = False - ub_algo = None - if ub_overlap_ag: - 
ln_out_total = ub_obj_lnout.get_ubuf_output(1) - if not return_layernorm_output: - ln_out = torch.empty_like(ln_out) - if ub_obj_lnout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + ln_out_return = ln_out if return_layernorm_output else None + + # Prepare GEMM input + # Note: Cast to expected dtype and perform tensor-parallel communication + if with_input_all_gather and not ub_overlap_ag_fprop: + with_quantized_all_gather = fp8 + if return_layernorm_output and return_layernorm_output_gathered: + with_quantized_all_gather = False + if fp8: + input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=(input_quantizer if with_quantized_all_gather else None), + ) + if return_layernorm_output and return_layernorm_output_gathered: + ln_out_return = ln_out_total + if fp8 and not with_quantized_all_gather: + ln_out_total = input_quantizer(ln_out_total) + else: + if ub_overlap_ag_fprop: + ln_out_total = ub_obj_fprop.get_buffer(input_quantizer) else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif parallel_mode == "column" and sequence_parallel: - ln_out_gathered = True - ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) + if fp8: + if not isinstance(ln_out, QuantizedTensor): + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + ln_out = input_quantizer(ln_out) + elif backward_needs_input: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + ln_out_total = ln_out + + # Cast weight to expected dtype + weightmat = weight + quantized_weight = False + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) else: - ln_out_total = ln_out + if not isinstance(weight, QuantizedTensor): + quantized_weight = True + + # Configure quantizer + if weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=True) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None 
or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) - # If residual connection is after LN, we need `ln_out_return` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. - if return_layernorm_output: - ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(ln_out_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG if fp8: - if ub_overlap_ag: - ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0) - tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - out=ln_out_fp8, - scale_inv=ln_out_scale_inv, - ) - ln_out = torch.empty_like(ln_out_fp8) - else: - ln_out_total = tex.cast_to_fp8( - ln_out_total, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=ln_out_scale_inv, - ) - if ln_out_gathered: - 
rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[slice_start:slice_end, ...] - else: - ln_out = ln_out_total - - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - - assert isinstance(weight_fp8, Float8Tensor) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - ln_out_scale_inv.fill_(ln_out_scale_inv.item()) - - if fp8_output: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) - else: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, - None, - None, - activation_dtype, - ) - out, _ = tex.fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ln_out_total, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - output_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=output_te_dtype, - ) - if output_dtype == torch.uint8: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - 
fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) - else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - out, _, _ = tex.gemm( - weight, - ln_out_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." 
+ ln_out_total = ub_obj.get_buffer(input_quantizer) + + out, *_, rs_out = general_gemm( + weightmat, + ln_out_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=activation_dtype, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, + ) + if not weight.requires_grad: + if not return_layernorm_output: + ln_out = ln_out_total = None + clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - ln_weight.weight_offloading = True - weight.weight_offloading = True + if fp8 and weightmat is not None: + set_offloading_param(weightmat, "weight_offloading", True) + set_offloading_param(ln_weight, "weight_offloading", True) + set_offloading_param(weight, "weight_offloading", True) - inputmat.activation_offloading = True - if normalization == "LayerNorm": - mu.activation_offloading = True - rsigma.activation_offloading = True - ln_out.activation_offloading = True + set_offloading_param(inputmat, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(rsigma, "activation_offloading", True) + set_offloading_param(ln_out, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -319,25 +312,34 @@ def forward( fsdp_group, mu, rsigma, - weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + weightmat if quantized_weight else None, ln_out if weight.requires_grad else None, ) - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, + weightmat, + weight, + bias, ln_weight, + ln_out, mu, rsigma, - weight, - weight_fp8, - weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, - ln_out if weight.requires_grad else None, - ln_out_scale_inv, ) - + 
ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.requires_dgrad = inp.requires_grad + ctx.requires_wgrad = weight.requires_grad + ctx.quantized_weight = quantized_weight + if fuse_wgrad_accumulation and weight.requires_grad: + ctx.main_grad = weight.main_grad + ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.input_quantizer = input_quantizer + ctx.owns_input = inputmat is not inp + ctx.weight = weight ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -349,14 +351,13 @@ def forward( ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.return_layernorm_output = return_layernorm_output - ctx.return_layernorm_output_gathered = ( - return_layernorm_output_gathered and ln_out_gathered - ) + ctx.return_layernorm_output_gathered = return_layernorm_output_gathered ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -368,10 +369,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first 
dimension changes for SP out = out.view(-1, *inp_shape[1:-1], out_features) @@ -389,23 +393,42 @@ def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if isinstance(grad_outputs[0], Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ - 0 - ]._scale_inv with torch.cuda.nvtx.range("_LayerNormLinear_backward"): - ( + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, + weight, + _, + bias, ln_weight, + ln_out, mu, rsigma, - weight, - weight_fp8, - main_grad, - ln_out, - ln_out_scale_inv, - ) = ctx.saved_tensors + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -415,56 +438,93 @@ def backward( ctx.fsdp_shapes, mu, rsigma, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + weight if ctx.fp8 and ctx.quantized_weight else None, ln_out, ) + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. 
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) weight.main_grad = main_grad - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ctx.ub_name + "_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device + ) + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. 
+ ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True) + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) + + if ctx.grad_output_quantizer is not None: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) ( grad_output, - grad_output_c, - grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], ctx.parallel_mode == "row" + ctx, + grad_outputs[0], + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, ) - if ctx.ub_bulk_wgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_wgrad = False - - # Column Parallel Linear - # Overlap input AG with dgrad + # Prepare GEMM input + # Note: Perform tensor-parallel communication if needed + ln_out_total = None + ln_out_total_work = None if ( - weight.requires_grad - and (not ctx.ub_bulk_dgrad) + ctx.requires_wgrad and ctx.parallel_mode == "column" and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad ): - ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) + quantizer = None + if ctx.fp8: + quantizer = ctx.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) else: ln_out_total = ln_out - handle = None + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch @@ -472,218 +532,129 
@@ def backward( else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - dgrad_size = list(grad_output.size()) - dgrad_size[1] = weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub(ctx.ub_name + "_wgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub(ctx.ub_name + "_dgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - else: - dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + # dgrad GEMM + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - rs_out = None - if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = weight.size(1) - rs_out = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=grad_output.device - ) - if ub_obj_dgrad.is_p2p_overlap(): - if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - out_index, meta_tensor, out_te_type, out_type = ( - None, - None, - None, - ctx.activation_dtype, - ) - if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT1 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - 
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) - - # DGRAD: Evaluated unconditionally to feed into Linear backward - _ = tex.fp8_gemm( - weight_fp8.transpose_2d(), - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ( - grad_output_c._data - if isinstance(grad_output_c, Float8Tensor) - else grad_output_c - ), - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - out_type, - get_workspace(), - out=dgrad, - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_te_type, - ) - clear_tensor_data(grad_output_c) - else: - # DGRAD: Evaluated unconditionally to feed into Linear backward - _, _, _ = tex.gemm( - weight, - grad_output, - ctx.activation_dtype, - get_workspace(), - out=dgrad, - layout="NN", - grad=True, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - ) - if ctx.ub_bulk_dgrad: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if not ctx.ub_bulk_dgrad and handle is not None: - handle.wait() - if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + dgrad, *_ = general_gemm( + weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, + out_dtype=ctx.activation_dtype, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + + # Launch tensor-parallel communication + dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + 
grad_outputs[1].view_as(dgrad) - dgrad, handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + # Compute grad weight tensor wgrad = None - if weight.requires_grad: - if ctx.fp8: - # WGRAD - extra_output_tensor = None - if ctx.ub_bulk_wgrad: - if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output - extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=dgrad.device - ) - dgrad = extra_output_tensor + if ctx.requires_wgrad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) + if ctx.fp8: + # FP8 GEMM on Hopper only supports TN layout so the gathered input must have + # a valid transpose. + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. 
+ ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) else: - dgrad = ub_obj_dgrad.get_ubuf_output(0) - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - wgrad, _ = tex.fp8_gemm( - ln_out_total_t, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - ( - grad_output_t._data - if isinstance(grad_output_t, Float8Tensor) - else grad_output_t - ), - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_t, grad_output_t) - else: - ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( - ln_out_total, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - wgrad, _, _ = tex.gemm( - ln_out_total_c, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_c) + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. 
+ ln_out_total._create_transpose() + else: - # WGRAD - wgrad, grad_bias, _ = tex.gemm( - ln_out_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device ) + + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + wgrad, grad_bias_, *_, rs_out = general_gemm( + ln_out_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + bias=(bias if (grad_bias is None and not ctx.fp8) else None), + out=main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) + + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out + else: + dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if not ctx.return_layernorm_output: + # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme clear_tensor_data(ln_out_total) - if ctx.ub_bulk_wgrad: - dgrad = 
ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output - # Column Parallel Linear - if ( - (not ctx.ub_bulk_wgrad) - and ctx.parallel_mode == "column" - and ctx.tensor_parallel - and handle is not None - ): - handle.wait() + # Don't return grad bias if not needed + if not ctx.use_bias: + grad_bias = None - # LayerNorm gradient - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out.view(inputmat.shape) - else: - dgrad = dgrad.view(inputmat.shape) + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None # Residual gradient + dgrad = dgrad.view(inputmat.shape) if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + # Norm gradient dgamma = None dbeta = None if ctx.normalization == "LayerNorm": @@ -696,6 +667,7 @@ def backward( ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, ) + dgrad = dgrad.reshape(inputmat.size()) elif ctx.normalization == "RMSNorm": dgrad, dgamma = tex.rmsnorm_bwd( dgrad, @@ -705,14 +677,12 @@ def backward( ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, ) + dgrad = dgrad.reshape(inputmat.size()) dbeta = None clear_tensor_data(mu) clear_tensor_data(rsigma) - if not ctx.use_bias: - grad_bias = None - - if weight.requires_grad: + if ctx.requires_wgrad: # Handle custom DDP from mcore. 
if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): weight.grad_added_to_main_grad = True @@ -724,12 +694,7 @@ def backward( requires_grad=False, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + wgrad = None elif ctx.fuse_wgrad_accumulation: wgrad = None else: @@ -739,23 +704,26 @@ def backward( FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, Float8Tensor): - _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + # if ctx.fp8 and not isinstance(weight, QuantizedTensor): + # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, wgrad, - None, # weight_fp8 grad_bias, None, # use_bias None, # eps None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta None, # fuse_wgrad_accumulation + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -770,13 +738,16 @@ def backward( None, # bwd_ln_sm_margin None, # zero_centered_gamma None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad None, # ub_overlap_rs_dgrad - None, # ub_overlap_ag + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name - None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -887,10 +858,11 @@ def __init__( parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = 
False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -907,13 +879,6 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_ag = ub_overlap_ag - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]): - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name if tp_group is None: self.tp_size = tp_size @@ -939,9 +904,49 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column-parallel overlaps + self.ub_overlap_ag_fprop = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_overlap_rs_dgrad = ( + ub_overlap_rs_dgrad and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_bulk_wgrad = ( + ub_bulk_wgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + + # Row-parallel overlaps + self.ub_overlap_rs_fprop = ( + ub_overlap_rs and self.sequence_parallel and self.parallel_mode == "row" + ) + self.ub_overlap_ag_dgrad = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "row" + ) + if any( + [ + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + ] + ): + assert ub_name is not None, "Userbuffer name [string] is not set." 
+ self.ub_name = ub_name + self.eps = eps layer_norm_weight = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_weight", @@ -950,7 +955,7 @@ def __init__( ) if self.normalization != "RMSNorm": layer_norm_bias = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) @@ -1034,7 +1039,9 @@ def __init__( # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) # Construct weight parameter self.register_parameter( @@ -1159,7 +1166,9 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch) as inp: + with self.prepare_forward( + inp, allow_non_contiguous=False # removed .contiguous from inside the layer + ) as inp: # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] @@ -1171,35 +1180,20 @@ def forward( ) else: unfused_weights = [w.dequantize() for w in unfused_weights] - weight_tensor = _noop_cat(unfused_weights) + + weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) else: bias_tensor = getattr(self, self.bias_names[0]) # Unused - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if 
present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output) if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1212,15 +1206,18 @@ def forward( self.layer_norm_weight, self.layer_norm_bias, weight_tensor, - weight_fp8, bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1235,13 +1232,16 @@ def forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_ag, self.ub_name, - fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = fwd_fn(*args) @@ -1258,3 +1258,27 @@ def forward( if self.return_layernorm_output: return out, ln_out return out + + def _get_quantizers(self, fp8_output): + if not self.fp8: + return [None] * 5 + grad_input_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = 
self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = False + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7bcbb1eb7d..88eebc8e6c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -5,12 +5,16 @@ """LayerNormMLP API""" import os import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn.parameter import Parameter from torch.nn import init +import transformer_engine_torch as tex + from .base import ( get_workspace, _ub_communicators, @@ -20,7 +24,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -35,6 +39,7 @@ assert_dim_for_fp8_exec, clear_tensor_data, requires_grad, + non_tn_fp8_gemm_supported, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -45,30 +50,39 @@ use_reentrant_activation_recompute, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, - _fsdp_gather_tensors, ) -from .. 
import cpp_extensions as tex - -from ..constants import dist_group_type, TE_DType +from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor -from ._common import _apply_normalization -from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ._common import apply_normalization, _fix_gathered_fp8_transpose +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..cpp_extensions import ( + general_gemm, +) __all__ = ["LayerNormMLP"] def _act_func(activation: str): funcs = { - "gelu": (tex.gelu, tex.dgelu), - "relu": (tex.relu, tex.drelu), - "geglu": (tex.geglu, tex.dgeglu), - "reglu": (tex.reglu, tex.dreglu), - "swiglu": (tex.swiglu, tex.dswiglu), - "qgelu": (tex.qgelu, tex.dqgelu), - "srelu": (tex.srelu, tex.dsrelu), + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -87,19 +101,24 @@ def forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, - fc1_weight_fp8: Optional[torch.Tensor], fc1_bias: torch.Tensor, use_fc1_bias: bool, fc2_weight: torch.Tensor, - fc2_weight_fp8: Optional[torch.Tensor], fc2_bias: torch.Tensor, use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], 
fuse_wgrad_accumulation: bool, + fc1_input_quantizer: Optional[Quantizer], + fc1_weight_quantizer: Optional[Quantizer], + fc2_input_quantizer: Optional[Quantizer], + fc2_weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_fc2_output_quantizer: Optional[Quantizer], + grad_fc1_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -108,7 +127,7 @@ def forward( activation_dtype: torch.dtype, return_layernorm_output: bool, return_layernorm_output_gathered: bool, - bias_gelu_nvfusion: bool, + bias_gelu_fusion: bool, set_parallel_mode: bool, is_grad_enabled: bool, fwd_ln_sm_margin: int, @@ -116,26 +135,34 @@ def forward( zero_centered_gamma: bool, activation: str, normalization: str, + ub_overlap_ag: bool, + ub_overlap_rs: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, gemm_gelu_fusion: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring + + in_features, inp_shape = ln_weight.numel(), inp.shape # Make sure input dimensions are compatible - in_features = ln_weight.numel() - inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(fc1_weight) - assert_dim_for_fp8_exec(fc2_weight) + assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + if ( + any([ub_overlap_ag, ub_overlap_rs]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) activation_func = _act_func(activation)[0] + device = inp.device # Cast for native AMP inputmat = 
cast_if_needed(inputmat, activation_dtype) @@ -143,314 +170,250 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + # for standard fp8: layernorm output = FP8 + # only output of the linear is returned + # for return_layernorm_output: layernorm output = High precision, then cast to FP8 + # high precision layernorm output and output of the linear are returned + with_quantized_norm = fp8 and not return_layernorm_output + tp_world_size = get_distributed_world_size(tp_group) - if ub_overlap_ag: - if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: - ub_overlap_ag = False + ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output + ub_overlap_rs = ub_overlap_rs and is_grad_enabled + with_input_all_gather_nccl = sequence_parallel and not ub_overlap_ag + backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + + # Configure quantizer for normalization output + if fp8 and fc1_input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_quantized_norm: + if with_input_all_gather_nccl: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(fc1_input_quantizer, MXFP8Quantizer): + with_quantized_norm = False + else: + fc1_input_quantizer.set_usage( + rowwise=True, + columnwise=backwards_needs_fc1_input, + ) + + ub_obj_lnout = None + ln_out = None if ub_overlap_ag: ub_obj_lnout = get_ub("fc1_fprop") - ln_out = ub_obj_lnout.get_ubuf_output(0) - else: - ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype + ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True) + elif not with_quantized_norm: ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" ) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs - - fp8_dtype_forward = 
get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - ln_out, mu, rsigma = _apply_normalization( + # Apply normalization + ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, ln_weight, ln_bias, eps, - fp8 and not return_layernorm_output, - fp8_meta, + fc1_input_quantizer if with_quantized_norm else None, + inp.dtype, normalization, fwd_ln_sm_margin, zero_centered_gamma, - is_grad_enabled, ) - # Column Parallel Linear + # Prepare GEMM input + # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False - ub_algo_ag = None - if ub_overlap_ag: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - ln_out = torch.empty_like(ln_out) - if ub_obj_lnout.is_atomic_gemm(): - ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif set_parallel_mode and sequence_parallel: + with_quantized_all_gather = fp8 + if with_input_all_gather_nccl: + if return_layernorm_output and return_layernorm_output_gathered: + with_quantized_all_gather = False + if fp8: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=(fc1_input_quantizer if with_quantized_all_gather else None), + ) ln_out_gathered = True - ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: - ln_out_total = ln_out + with_quantized_all_gather = False + if ub_overlap_ag: + ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) + else: + ln_out_total = ln_out # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. 
+ ln_out_return = None if return_layernorm_output: ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out - if fp8: - if ub_overlap_ag: - ln_out = tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if fp8 and not with_quantized_all_gather: + ln_out_total = fc1_input_quantizer(ln_out_total) + if ln_out_gathered: + rank = torch.distributed.get_rank(tp_group) + slice_start = rank * ln_out.size(0) + slice_end = (rank + 1) * ln_out.size(0) + ln_out = ln_out_total[ + slice_start:slice_end, ... + ] # TODO(pgadzinski) - check this # pylint: disable=fixme else: - ln_out_total = tex.cast_to_fp8( - ln_out_total, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[slice_start:slice_end, ...] - else: - ln_out = ln_out_total - - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias - fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias - - # Use FP8 weights - if fc1_weight_fp8 is None: - fc1_weight_fp8 = fc1_weight - if fc2_weight_fp8 is None: - fc2_weight_fp8 = fc2_weight - - assert isinstance(fc1_weight_fp8, Float8Tensor) - assert isinstance(fc2_weight_fp8, Float8Tensor) - - # Perform FP8 GEMM - fp8_gemm_args = [ - fc1_weight_fp8._data, - fc1_weight_fp8._scale_inv, - 0, - fc1_weight_fp8._fp8_dtype, - ln_out_total, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - activation_dtype, - get_workspace(), - ] - fp8_gemm_kwargs = { - "bias": fc1_bias, - "use_bias": use_fc1_bias, - "use_split_accumulator": _2X_ACC_FPROP, - "ub_algo": ub_algo_ag if ub_overlap_ag else None, - "ub": ub_obj_lnout if ub_overlap_ag 
else None, - "extra_output_tensor": ln_out if ub_overlap_ag else None, - } - if gemm_gelu_fusion: - fp8_gemm_args[8] = torch.uint8 # out_dtype - fp8_gemm_kwargs.update( - { - "gelu": True, - "out_index": tex.FP8FwdTensors.GEMM2_INPUT, - "fp8_meta_tensor": fp8_meta["scaling_fwd"], - "D_dtype": fp8_dtype_forward, - } + ln_out = ln_out_total + + # Cast weights to expected dtype + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + if not fp8: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) + else: + # If weights are not quantized, we call get_weight_workspace, + # which handles weight caching etc. + if not isinstance(fc1_weight, QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs) - if not is_grad_enabled: - clear_tensor_data(ln_out_total) - - # Perform activation - if gemm_gelu_fusion: - gelu_out, fc1_out = fp8_gemm_out - else: - fc1_out, _ = fp8_gemm_out - gelu_out = activation_func( - fc1_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, + if not isinstance(fc2_weight, QuantizedTensor): + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - if not is_grad_enabled: - clear_tensor_data(fc1_out) - - fc2_out_index, fc2_meta_tensor, 
fc2_te_type, out_type = ( - None, - None, - None, - activation_dtype, - ) - rs_out = None - ub_algo_rs = None - if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) - dim_size = list(gelu_out.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc2_weight_fp8.size(0) - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - if ub_obj_fc2out.is_p2p_overlap(): - if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - - if ub_obj_fc2out.is_fp8_ubuf(): - fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT - fc2_meta_tensor = fp8_meta["scaling_fwd"] - fc2_te_type = fp8_dtype_forward - out_type = torch.uint8 - ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index]) + # Cast biases to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + if fc1_bias is not None: + fc1_bias = cast_if_needed(fc1_bias, bias_dtype) + if fc2_bias is not None: + fc2_bias = cast_if_needed(fc2_bias, bias_dtype) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if fc1_input_quantizer is not None: + fc1_input_quantizer.calibrate(ln_out_total) + if fc1_weight_quantizer is not None: + fc1_weight_quantizer.calibrate(fc1_weight) + + # FC1 GEMM + + # There are 2 fusions possible: + # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion, + # - bias_gelu_fusion - only for full precision.
+ # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performed + if activation != "gelu": + gemm_gelu_fusion = bias_gelu_fusion = False + else: + if fp8: + assert not bias_gelu_fusion, "Bias gelu fusion is supported only for full precision" else: - dim_size = list(gelu_out.size()) - dim_size[1] = fc2_weight_fp8.size(0) - fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - - _ = tex.fp8_gemm( - fc2_weight_fp8._data, - fc2_weight_fp8._scale_inv, - 0, - fc2_weight_fp8._fp8_dtype, - gelu_out, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - out_type, - get_workspace(), - bias=fc2_bias, - use_bias=use_fc2_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=fc2_out, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=fc2_out_index, - fp8_meta_tensor=fc2_meta_tensor, - D_dtype=fc2_te_type, - ) - if not is_grad_enabled: - clear_tensor_data(gelu_out) + gemm_gelu_fusion = True + if gemm_gelu_fusion and bias_gelu_fusion: + gemm_gelu_fusion = False + + fc1_outputs = general_gemm( + fc1_weight_final, + ln_out_total, + get_workspace(), + quantization_params=( + fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + ), + out_dtype=activation_dtype, + bias=( + fc1_bias if not bias_gelu_fusion else None + ), # otherwise bias is added later (fused with gelu) + gelu=gemm_gelu_fusion, + accumulate=_2X_ACC_FPROP, + ub=ub_obj_lnout, + ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, + ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): + clear_tensor_data(ln_out_total) + + # ACTIVATION - sometimes activation is fused with the GEMM above.
+ + fc1_out_without_bias = None + + if bias_gelu_fusion: + fc1_out = None + fc1_out_without_bias, *_ = fc1_outputs + act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) + elif gemm_gelu_fusion: + act_out, _, fc1_out, _ = fc1_outputs else: - # Cast for native AMP - fc1_weight = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight = cast_if_needed(fc2_weight, activation_dtype) - fc1_bias = cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias - fc2_bias = cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias - - if fp8_calibration: - # amax of fc1 input - amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of fc1 weight - amin, amax = fc1_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - fc1_outputs = tex.gemm( - fc1_weight, - ln_out_total, - activation_dtype, - get_workspace(), - bias=fc1_bias, - use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, - gelu=not bias_gelu_nvfusion and (activation == "gelu"), - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) - if not is_grad_enabled and not return_layernorm_output: - clear_tensor_data(ln_out_total) + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, fc2_input_quantizer) - if bias_gelu_nvfusion: - fc1_out, _, _ = fc1_outputs - gelu_out = bias_gelu_fused(fc1_out, fc1_bias) - else: - if activation == "gelu": - gelu_out, _, fc1_out = fc1_outputs - else: - fc1_out, _, _ = fc1_outputs - gelu_out = activation_func( - fc1_out, None, tex.FP8FwdTensors.GEMM2_INPUT, TE_DType[fc1_out.dtype] - ) - if not is_grad_enabled: - clear_tensor_data(fc1_out) - - if fp8_calibration: - # amax of fc2 input - amin, amax = gelu_out.aminmax() - 
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = torch.max( - -amin, amax - ).float() - # amax of fc2 weight - amin, amax = fc2_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = torch.max( - -amin, amax - ).float() - - if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) - dim_size = list(gelu_out.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc2_weight.size(0) - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - else: - dim_size = list(gelu_out.size()) - dim_size[1] = fc2_weight.size(0) - fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - _ = tex.gemm( - fc2_weight, - gelu_out, - activation_dtype, - get_workspace(), - bias=fc2_bias, - use_bias=use_fc2_bias, - out=fc2_out, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - ) - if not is_grad_enabled: - clear_tensor_data(gelu_out) + if not is_grad_enabled: + clear_tensor_data(fc1_out) + + if fp8_calibration: + fc2_input_quantizer.calibrate(act_out) + fc2_weight_quantizer.calibrate(fc2_weight) + + ub_obj_fc2out = None + rs_out = None + fc2_out = None + if ub_overlap_rs: + ub_obj_fc2out = get_ub("fc2_fprop") + dim_size = list(act_out.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc2_weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) + fc2_out = ub_obj_fc2out.get_buffer(output_quantizer) + else: + dim_size = list(act_out.size()) + dim_size[1] = fc2_weight.size(0) + fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device) + + # FC2 GEMM + _ = general_gemm( + fc2_weight_final, + act_out, + 
get_workspace(), + out_dtype=activation_dtype, + bias=fc2_bias, + quantization_params=output_quantizer, + out=fc2_out, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj_fc2out, + ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, + extra_output=rs_out, + ) + if not is_grad_enabled: + clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) if is_grad_enabled: if cpu_offloading: - if fp8 and fc1_weight_fp8 is not None: - fc1_weight_fp8.weight_offloading = True - if fp8 and fc2_weight_fp8 is not None: - fc2_weight_fp8.weight_offloading = True - ln_weight.weight_offloading = True - fc1_weight.weight_offloading = True - fc2_weight.weight_offloading = True - if fc1_bias is not None: - fc1_bias.weight_offloading = True - - inputmat.activation_offloading = True - if normalization == "LayerNorm": - mu.activation_offloading = True - rsigma.activation_offloading = True - ln_out.activation_offloading = True - fc1_out.activation_offloading = True - gelu_out.activation_offloading = True + if fp8 and fc1_weight_final is not None: + set_offloading_param(fc1_weight_final, "weight_offloading", True) + if fp8 and fc2_weight_final is not None: + set_offloading_param(fc2_weight_final, "weight_offloading", True) + set_offloading_param(ln_weight, "weight_offloading", True) + set_offloading_param(fc1_weight, "weight_offloading", True) + set_offloading_param(fc2_weight, "weight_offloading", True) + set_offloading_param(fc1_bias, "weight_offloading", True) + + set_offloading_param(inputmat, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(rsigma, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(ln_out, "activation_offloading", True) + set_offloading_param(fc1_out, "activation_offloading", True) + set_offloading_param(fc1_out_without_bias, "activation_offloading", True) + set_offloading_param(act_out, "activation_offloading", True) # Scatter 
intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -461,45 +424,68 @@ def forward( mu, rsigma, ln_out, - fc1_out, - gelu_out, - fc1_weight_fp8 if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + fc1_out_without_bias if bias_gelu_fusion else fc1_out, + act_out, + fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, + fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) - ctx.save_for_backward( + if not fc1_weight.requires_grad: + if not return_layernorm_output: + clear_tensor_data(ln_out) + ln_out = None + if not fc2_weight.requires_grad: + clear_tensor_data(act_out) + act_out = None + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, + ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer + fc1_weight_final, + fc1_bias, + fc1_out, + fc1_out_without_bias, + act_out, + fc2_weight_final, + fc2_bias, mu, rsigma, - ln_out if fc1_weight.requires_grad else None, - fc1_out, - gelu_out if fc2_weight.requires_grad else None, - fc1_weight, - fc1_weight_fp8, - fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, - fc2_weight, - fc2_weight_fp8, - fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, - fc1_bias, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) + if fuse_wgrad_accumulation: + ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None + ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None + + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer + ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.fc2_input_quantizer = 
fc2_input_quantizer + ctx.fc1_input_quantizer = fc1_input_quantizer + + ctx.fc1_weight_requires_grad = fc1_weight.requires_grad + ctx.fc2_weight_requires_grad = fc2_weight.requires_grad + ctx.fc1_weight = fc1_weight + ctx.fc2_weight = fc2_weight + + ctx.device = device ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_fc1_bias = use_fc1_bias ctx.use_fc2_bias = use_fc2_bias + ctx.use_bias = ctx.use_fc1_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape ctx.tp_group = tp_group ctx.tp_size = tp_size - ctx.bias_gelu_nvfusion = bias_gelu_nvfusion + ctx.bias_gelu_fusion = bias_gelu_fusion ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output_gathered = ( return_layernorm_output_gathered and ln_out_gathered @@ -511,7 +497,10 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag - ctx.requires_dgrad = inp.requires_grad + + ctx.requires_dgrad = ( + inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad + ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( @@ -547,499 +536,366 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_LayerNormMLP_backward"): - ( + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, ln_weight, - mu, - 
rsigma, ln_out, - fc1_out, - gelu_out, fc1_weight, - fc1_weight_fp8, - fc1_weight_main_grad, - fc2_weight, - fc2_weight_fp8, - fc2_weight_main_grad, fc1_bias, - fwd_scale_inverses, - ) = ctx.saved_tensors - - # Gather saved autograd context tensors when running with FSDP - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already - # shards/unshards the base weights so we don't do it ourselves - _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, + fc1_out, + fc1_out_without_bias, + act_out, + fc2_weight, + fc2_bias, mu, rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + fc1_weight_main_grad = ( + ctx.fc1_main_grad + if fc1_weight is not None + and ctx.fuse_wgrad_accumulation + and ctx.fc1_weight_requires_grad + else None + ) + fc2_weight_main_grad = ( + ctx.fc2_main_grad + if fc2_weight is not None + and ctx.fuse_wgrad_accumulation + and ctx.fc2_weight_requires_grad + else None ) - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad) - fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad) - + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. 
+ if ctx.fuse_wgrad_accumulation: fc1_weight.main_grad = fc1_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad - activation_func = _act_func(ctx.activation)[1] - - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not fc1_weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub("fc1_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - if ctx.ub_overlap_ag: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_ag = False + # TODO: Fix this # pylint: disable=fixme + # Gather saved autograd context tensors when running with FSDP + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + # _fsdp_gather_tensors( + # ctx.fsdp_group, + # ctx.fsdp_shapes, + # mu, + # rsigma, + # ln_out, + # fc1_out_without_bias if bias_gelu_nvfusion else fc1_out,, + # gelu_out, + # fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None, + # fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + # ) + + # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required + ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + if ctx.grad_fc2_output_quantizer is not None: + ctx.grad_fc2_output_quantizer.set_usage( + rowwise=True, + columnwise=True, + ) - ub_algo = None + ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - 
dim_size = list(grad_outputs[0].size()) - dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub("fc2_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess + ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, - grad_output_c, - grad_output_t, fc2_bias_grad, - ) = TransformerEngineBaseModule.grad_output_preprocess(ctx, grad_outputs[0], True) - - if ctx.ub_bulk_wgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not fc1_weight.requires_grad: - ctx.ub_bulk_wgrad = False - # Column Parallel Linear - # Overlap input AG with dgrad + ) = TransformerEngineBaseModule.grad_output_preprocess( + ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ) + + # Prepare FC1 GEMM input + # Note: Perform tensor-parallel communication if needed + ln_out_total = None + ln_out_total_work = None if ( - fc1_weight.requires_grad - and (not ctx.ub_bulk_dgrad) - and ctx.set_parallel_mode + ctx.fc1_weight_requires_grad + and ctx.tensor_parallel and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad ): - ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) + quantizer = None + if ctx.fp8: + quantizer = ctx.fc1_input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) else: ln_out_total = ln_out - handle = None + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + # There are 5 possible fusion paths + # 1 high-precision 
bias_gelu_fusion: gemm, FC1_bias + gelu, + # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize + # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize + # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize + # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + fc2_dgrad_gemm_gelu_fusion = ( + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + ) - fc2_wgrad = None - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - # FC2 DGRAD; Unconditional - fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_fp8.transpose_2d(), - fc2_weight_fp8._scale_inv, - 0, - fc2_weight_fp8._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, - ) - if ctx.ub_overlap_ag: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - clear_tensor_data(grad_output_c) - - # FC2 WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if fc2_weight.requires_grad: - gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) - clear_tensor_data(gelu_out) - fc2_wgrad, _ = tex.fp8_gemm( - gelu_out_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ) - clear_tensor_data(gelu_out_t, grad_output_t) - - if ctx.activation == "gelu": - fc1_bias_grad, dgelu, 
dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( - fc2_dgrad, - fc1_out, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - else: - dgelu = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_fused( - dgelu, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - clear_tensor_data(fc1_out) - else: - if fc2_weight.requires_grad: - gelu_out_c = torch.ops.tex_ts.cast_from_fp8_ts( - gelu_out, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - clear_tensor_data(gelu_out) - fc2_wgrad, _, _ = tex.gemm( - gelu_out_c, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=False, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - clear_tensor_data(gelu_out_c) - - if ctx.activation == "gelu": - fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( - fc2_dgrad, fc1_out, fc1_bias - ) - else: - dgelu_no_fp8 = activation_func( - fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype] - ) - fc1_bias_grad = dgelu_no_fp8.sum(dim=0) - clear_tensor_data(fc1_out) - - dgelu = tex.cast_to_fp8( - dgelu_no_fp8, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - dgelu_t = None + # FC2 DGRAD; Unconditional + gemm_output, *_ = general_gemm( + fc2_weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=( + ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ), # high precision to activation + out_dtype=ctx.activation_dtype, + gelu=fc2_dgrad_gemm_gelu_fusion, + gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_fc2_dgrad, + ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None, + ) + if fc2_dgrad_gemm_gelu_fusion: + 
dact = gemm_output + fc2_dgrad = None + else: + fc2_dgrad = gemm_output - out_index, meta_tensor, out_te_type, out_type = ( - None, - None, - None, - ctx.activation_dtype, - ) - fc1_dgrad_size = list(dgelu.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - # Get/alloc fc1_dgrad - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - else: - fc1_dgrad = torch.empty( - fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device - ) + # FC2 WGRAD + if ctx.fc2_weight_requires_grad: + if isinstance(act_out, QuantizedTensor): + act_out.update_usage(rowwise_usage=True, columnwise_usage=True) - # FP8 RS - if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT2 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + if isinstance(grad_output, QuantizedTensor): + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap - rs_out = None - if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(dgelu.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight_fp8.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) - if ub_obj_dgrad.is_p2p_overlap(): - if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = 
tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - # FC1 DGRAD: Unconditional - _ = tex.fp8_gemm( - fc1_weight_fp8.transpose_2d(), - fc1_weight_fp8._scale_inv, - 0, - fc1_weight_fp8._fp8_dtype, - dgelu, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - out_type, - get_workspace(), - out=fc1_dgrad, - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_te_type, - ) - else: - # FC2 DGRAD; Unconditional - fc2_dgrad, _, _ = tex.gemm( - fc2_weight, + fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( + act_out, grad_output, - ctx.activation_dtype, get_workspace(), - layout="NN", - gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == "gelu"), + out_dtype=ctx.activation_dtype, + quantization_params=None, # wgrad in high precision + layout="NT", grad=True, - gelu_input=fc1_out, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None - ), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, + accumulate=accumulate_wgrad_into_param_main_grad, + use_split_accumulator=_2X_ACC_WGRAD, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) + if fc2_bias_grad is None: + fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ + clear_tensor_data(act_out) + + # bias computation + fc1_bias_grad = None + fuse_gemm_and_bias_fc1_wgrad = False + if ctx.grad_fc1_output_quantizer is not None: + ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.bias_gelu_fusion: + # Fusion: gemm, bias + gelu + assert ctx.activation == "gelu" + assert not ctx.fp8 + fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) + if ctx.grad_fc1_output_quantizer is not None: + dact = 
ctx.grad_fc1_output_quantizer(dact) + elif _act_func(ctx.activation)[2] is not None and ctx.fp8: + # Fusion: gemm, bias + gelu + quantize + dbias_dact_quantize_func = _act_func(ctx.activation)[2] + fc1_bias_grad, dact = dbias_dact_quantize_func( + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + ) # quantize bgrad gelu fused + else: + # Fusion: gemm + gelu, + if not fc2_dgrad_gemm_gelu_fusion: + activation_func_bwd = _act_func(ctx.activation)[1] + dact = activation_func_bwd( + fc2_dgrad, fc1_out.to(ctx.activation_dtype), None + ) # activation in high precision - # FC2 WGRAD - if fc2_weight.requires_grad: - fc2_wgrad, fc2_bias_grad, _ = tex.gemm( - gelu_out, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_fc2_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - clear_tensor_data(gelu_out) - - if ctx.bias_gelu_nvfusion and ctx.activation == "gelu": - fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) - else: - if ctx.activation != "gelu": - fc2_dgrad = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - - # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM - # and will not be calculated in case wgrad is not required. - if not fc1_weight.requires_grad: - fc1_bias_grad = fc2_dgrad.sum(dim=0) - - # Overwrite data. Deleting the tensor does not release underlying memory. 
- clear_tensor_data(fc1_out) - dgelu = fc2_dgrad - - fc1_dgrad_size = list(dgelu.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + if ctx.fp8: + fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) else: - fc1_dgrad = torch.empty( - fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + fuse_gemm_and_bias_fc1_wgrad = ( + True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 ) + # it may not be calculated in case wgrad is not required. + if fc1_bias is not None: + if not ctx.fc1_weight_requires_grad and fc1_bias.requires_grad: + fc1_bias_grad = dact.sum(dim=0) + + # Overwrite data. Deleting the tensor does not release underlying memory. + clear_tensor_data(fc1_out, fc1_out_without_bias) + + # Set UB algo and UB obj for fc1_dgrad/wgrad bulk/pipelined overlap + ub_obj_fc1_dgrad = None + ub_obj_fc1_wgrad = None + ub_type_fc1_dgrad = None + fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] + fc1_dgrad_rs_out = None + fc1_dgrad_bulk = None + if ctx.ub_overlap_rs_dgrad: + # Overlap DGRAD+RS + ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.RS + fc1_dgrad_rs_out = torch.empty( + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + ) - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + else: if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(dgelu.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) - if ub_obj_dgrad.is_p2p_overlap(): - ub_algo 
= tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - # FC1 DGRAD: Unconditional - _ = tex.gemm( - fc1_weight, - dgelu, - ctx.activation_dtype, - get_workspace(), - out=fc1_dgrad, - layout="NN", - grad=True, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - ) + # Overlap ln_out all-gather with DGRAD compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. + ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.AG + ub_obj_fc1_dgrad.copy_into_buffer( + ln_out, ctx.fc1_input_quantizer, local_chunk=True + ) + + if ctx.ub_bulk_wgrad: + # Overlap FC1 DGRAD reduce-scatter with WGRAD compute + ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) + + # FC1 DGRAD: Unconditional + fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( + fc1_weight, + dact, + get_workspace(), + out=fc1_dgrad_bulk, + out_dtype=ctx.activation_dtype, + layout="NN", + grad=True, + ub=ub_obj_fc1_dgrad, + ub_type=ub_type_fc1_dgrad, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) - if ctx.ub_bulk_dgrad: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad - if ctx.set_parallel_mode and ctx.sequence_parallel: - if not ctx.ub_bulk_dgrad and handle is not None: - handle.wait() - if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + fc1_dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + fc1_dgrad = fc1_dgrad_rs_out + elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: fc1_dgrad = fc1_dgrad 
+ grad_outputs[1].view_as(fc1_dgrad) - fc1_dgrad, handle = reduce_scatter_along_first_dim( - fc1_dgrad, ctx.tp_group, async_op=True + fc1_dgrad, fc1_dgrad_work = reduce_scatter_along_first_dim( + fc1_dgrad, + ctx.tp_group, + async_op=True, ) - elif ctx.set_parallel_mode and ctx.tensor_parallel: - fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel: + fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + # FC1 WGRAD fc1_wgrad = None - if fc1_weight.requires_grad: - if ctx.fp8: - # FC1 WGRAD - extra_output_tensor = None - if ctx.ub_bulk_wgrad: - if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output - extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device - ) - fc1_dgrad = extra_output_tensor - else: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - fc1_wgrad, _ = tex.fp8_gemm( - ln_out_total_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - dgelu_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_t, dgelu_t) - else: - ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( - ln_out_total, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - fc1_wgrad, _, _ = tex.gemm( - ln_out_total_c, - dgelu_no_fp8, - ctx.activation_dtype, - 
get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_c, dgelu_no_fp8) + if ctx.fc1_weight_requires_grad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer) + if ctx.fp8: + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + ln_out_total._create_transpose() + else: - # FC1 WGRAD - fc1_wgrad_outputs = tex.gemm( - ln_out_total, - dgelu, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=not ctx.bias_gelu_nvfusion, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + # Make sure GEMM inputs have expected data + if isinstance(ln_out_total, QuantizedTensor): + ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True) + if isinstance(dact, QuantizedTensor): + dact.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad_rs_out = torch.empty( + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" ) - clear_tensor_data(ln_out_total, dgelu) - if 
ctx.bias_gelu_nvfusion: - fc1_wgrad, _, _ = fc1_wgrad_outputs - else: - fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs - if ctx.ub_bulk_wgrad: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + fc1_wgrad_outputs = general_gemm( + ln_out_total, + dact, + get_workspace(), + out_dtype=ctx.activation_dtype, + layout="NT", + grad=fuse_gemm_and_bias_fc1_wgrad, + bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_fc1_wgrad, + ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) - # Column Parallel Linear - if ( - (not ctx.ub_bulk_wgrad) - and ctx.set_parallel_mode - and ctx.tensor_parallel - and handle is not None - ): - handle.wait() + clear_tensor_data(ln_out_total, dact) - # LayerNorm gradient - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out.view(inputmat.shape) - else: - dgrad = fc1_dgrad.view(inputmat.shape) + if fuse_gemm_and_bias_fc1_wgrad: + fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs + else: + fc1_wgrad, *_ = fc1_wgrad_outputs + + if ctx.ub_bulk_wgrad: + if ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad = fc1_dgrad_rs_out + else: + fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True) + + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if fc1_dgrad_work is not None: + fc1_dgrad_work.wait() + fc1_dgrad_work = None # Residual gradient + dgrad = fc1_dgrad.view(inputmat.shape) if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + # Norm gradient dgamma = None dbeta = None if ctx.normalization == "LayerNorm": @@ -1062,10 +918,9 @@ def backward( ctx.zero_centered_gamma, ) dbeta = None - clear_tensor_data(mu) - clear_tensor_data(rsigma) + clear_tensor_data(mu, 
rsigma) - if fc1_weight.requires_grad: + if ctx.fc1_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): fc1_weight.grad_added_to_main_grad = True @@ -1077,18 +932,13 @@ def backward( requires_grad=False, ) else: - fc1_wgrad = torch.empty( - fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + fc1_wgrad = None elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: fc1_wgrad = None - if fc2_weight.requires_grad: + if ctx.fc2_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): fc2_weight.grad_added_to_main_grad = True @@ -1100,12 +950,7 @@ def backward( requires_grad=False, ) else: - fc2_wgrad = torch.empty( - fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + fc2_wgrad = None elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: @@ -1114,34 +959,37 @@ def backward( if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + # FIX THIS # Scatter Fp8 tranposed-weight buffers - if ctx.fp8: - _fsdp_scatter_tensors( - ctx.fsdp_group, - fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, - ) - + # if ctx.fp8: + # _fsdp_scatter_tensors( + # ctx.fsdp_group, + # fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, + # fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, + # ) return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, fc1_wgrad, - None, # fc1_weight_fp8 - # Due to bias gelu nvfusion available in the bf16 case, fc1_bias_grad is calculated at - # different paths and this confused the linter. 
- fc1_bias_grad if ctx.use_fc1_bias else None, # pylint: disable=used-before-assignment + fc1_bias_grad if ctx.use_fc1_bias else None, None, # use_fc1_bias - fc2_wgrad, - None, # fc2_weight_fp8 + fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_bias_grad if ctx.use_fc2_bias else None, None, # use_fc2_bias None, # eps None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta None, # fuse_wgrad_accumulation + None, # fc1_input_quantizer + None, # fc1_weight_quantizer + None, # fc2_input_quantizer + None, # fc2_weight_quantizer + None, # output_quantizer + None, # grad_fc2_output_quantizer + None, # grad_fc1_output_quantizer + None, # grad_input_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -1150,7 +998,7 @@ def backward( None, # activation_dtype None, # return_layernorm_output None, # return_layernorm_output_gathered - None, # bias_gelu_nvfusion + None, # bias_gelu_fusion None, # set_parallel_mode None, # is_grad_enabled None, # fwd_ln_sm_margin @@ -1158,13 +1006,15 @@ def backward( None, # zero_centered_gamma None, # activation None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad - None, # ub_overlap_rs_dgrad - None, # ub_overlap_rs None, # ub_overlap_ag + None, # ub_overlap_rs + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # gemm_gelu_fusion None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -1285,11 +1135,11 @@ def __init__( set_parallel_mode: bool = False, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ) -> None: super().__init__() @@ -1308,11 +1158,7 @@ def __init__( ) self.set_parallel_mode = 
set_parallel_mode self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag + # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) @@ -1337,6 +1183,16 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.size_per_partition = divide(ffn_hidden_size, self.tp_size) + self.ub_overlap_ag = ub_overlap_ag and self.sequence_parallel + self.ub_overlap_rs = ub_overlap_rs and self.sequence_parallel + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad and self.sequence_parallel + self.ub_bulk_wgrad = ( + ub_bulk_wgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad + ) + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1357,7 +1213,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["reglu", "geglu", "swiglu"]: + if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1491,61 +1347,30 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + with self.prepare_forward(inp, num_gemms=2) as inp: + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers() # Get weight tensors fc1_weight = self.fc1_weight - fc1_bias = self.fc1_bias + fc1_bias = self.fc1_bias if self.use_bias else None 
fc2_weight = self.fc2_weight - fc2_bias = self.fc2_bias + fc2_bias = self.fc2_bias if self.use_bias else None if not self.fp8: if isinstance(fc1_weight, Float8Tensor): fc1_weight = fc1_weight.from_float8() if isinstance(fc2_weight, Float8Tensor): fc2_weight = fc2_weight.from_float8() - # Cast weights to FP8 if needed - fc1_weight_fp8 = None - fc2_weight_fp8 = None - if self.fp8: - update_workspace = is_first_microbatch is None or is_first_microbatch - if isinstance(fc1_weight, Float8Tensor): - if fc1_weight._transpose is not None: - fc1_weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - cache_name = None - if is_first_microbatch is not None: - cache_name = "fc1_weight" - fc1_weight_fp8 = self.get_fp8_workspace( - tensor=fc1_weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=cache_name, - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - if isinstance(fc2_weight, Float8Tensor): - if fc2_weight._transpose is not None: - fc2_weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - cache_name = None - if is_first_microbatch is not None: - cache_name = "fc2_weight" - fc2_weight_fp8 = self.get_fp8_workspace( - tensor=fc2_weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, - cache_name=cache_name, - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): self.bias_gelu_nvfusion = False @@ -1561,19 +1386,24 @@ def forward( self.layer_norm_weight, self.layer_norm_bias, fc1_weight, - fc1_weight_fp8, fc1_bias, self.use_bias, fc2_weight, - fc2_weight_fp8, fc2_bias, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, 
self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1582,7 +1412,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion, + self.bias_gelu_nvfusion and not self.fp8, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1590,13 +1420,15 @@ def forward( self.zero_centered_gamma, self.activation, self.normalization, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_rs, self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.gemm_gelu_fusion, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = fwd_fn(*args) @@ -1613,3 +1445,48 @@ def forward( if self.return_layernorm_output: return out, ln_out return out + + def _get_quantizers(self): + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) = [None] * 8 + if self.fp8: + fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + fc1_input_quantizer.internal = False # temporary + fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + fc1_weight_quantizer.internal = True + fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + fc2_input_quantizer.set_usage( + rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + ) + fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + fc2_weight_quantizer.internal = True + if torch.is_grad_enabled(): + 
grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ] + grad_fc2_output_quantizer.internal = True + grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ] + grad_fc1_output_quantizer.internal = True + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] + grad_input_quantizer.internal = True + + return ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5893c4ea3c..460ce87bc6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,9 +3,9 @@ # See LICENSE for license information. """Linear API""" +from typing import Callable, Dict, Optional, Tuple, Union from functools import reduce from operator import mul as multiply_op -from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -19,15 +19,15 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import _noop_cat -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ._common import noop_cat, _fix_gathered_fp8_transpose +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, init_method_constant, requires_grad, + non_tn_fp8_gemm_supported, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -35,23 +35,25 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, + is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) from ..cpp_extensions import ( - fp8_gemm, - gemm, - fp8_cast_transpose_fused, - cast_to_fp8, + general_gemm, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType 
+from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor -from ..export import is_in_onnx_export_mode -from ..tensor import QuantizedTensor -from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param __all__ = ["Linear"] @@ -64,15 +66,17 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: Union[Float8Tensor, torch.Tensor], - weight_fp8: Optional[Float8Tensor], + weight: torch.Tensor, inp: torch.Tensor, - bias: torch.Tensor, - use_bias: bool, + bias: Optional[torch.Tensor], is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -89,293 +93,186 @@ def forward( ub_bulk_dgrad: bool, ub_bulk_wgrad: bool, ub_name: str, - fp8_output: bool, + fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - is_input_fp8 = isinstance(inp, Float8Tensor) # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" - inputmat = inp.view(-1, in_features) - if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_ag_fprop = False if tp_world_size == 1 else 
ub_overlap_ag_fprop - ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop - - # Cast input to expected dtype - inputmat = cast_if_needed(inputmat, activation_dtype) - inputmat_t = None - inputmat_no_fp8 = inputmat - inputmat_scale_inv = None - - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if isinstance(inputmat, Float8Tensor): - inputmat_scale_inv = inputmat._scale_inv - else: - inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat, inputmat_t = fp8_cast_transpose_fused( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) - else: - # FP8 input for forward - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. 
- if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) - - # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop: - inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) - else: - inputmat_total = inputmat - + backward_needs_input = is_grad_enabled and weight.requires_grad + + # Prepare input tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + inputmat = inp + inputmat_total = None + with_input_all_gather_nccl = ( + parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop + ) + own_quantized_input = False if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - - assert isinstance(weight_fp8, Float8Tensor) - - if fp8_output: - out_index, meta_tensor, out_tedtype, out_pttype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) - else: - out_index, meta_tensor, out_tedtype, out_pttype = ( - None, - None, - None, - activation_dtype, + if ( + any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" ) - ub_obj = None - ub_algo = None - rs_out = None - inputmat_data = ( - inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total - ) - if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") - out = ub_obj.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj.is_p2p_overlap(): - if ub_obj.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P 
- else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj.is_fp8_ubuf(): - out_index = tex.FP8FwdTensors.GEMM1_OUTPUT - meta_tensor = fp8_meta["scaling_fwd"] - out_tedtype = fp8_dtype_forward - out_pttype = torch.uint8 - ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) - - elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") - assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer." - ub_obj.copy_input_to_ubuf(inputmat_data, True) - ub_obj.set_ubuf_scale_inv(inputmat_scale_inv) - if ub_obj.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - out_tedtype = TE_DType[activation_dtype] - out_pttype = activation_dtype - dim_size = list(inputmat_total.size()) - dim_size[0] *= tp_size - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) - + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_input_all_gather_nccl: + assert not isinstance( + inputmat, QuantizedTensor + ), "All gather of fp8 input is not supported" + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, + ) else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) - - _ = fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - inputmat_data, - inputmat_scale_inv, - 0, - fp8_dtype_forward, - out_pttype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=out, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - out_index=out_index, - 
fp8_meta_tensor=meta_tensor, - D_dtype=out_tedtype, - ) - if fp8_output: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, ) + if not isinstance(inputmat, QuantizedTensor): + inputmat = input_quantizer(inputmat) + elif backward_needs_input: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) + inputmat_total = inputmat else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = inputmat_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - ub_obj = None - ub_algo = None - rs_out = None - if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") - out = ub_obj.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj.is_p2p_overlap(): - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - - elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") - ub_obj.copy_input_to_ubuf(inputmat_total, True) - dim_size = list(inputmat_total.size()) - dim_size[0] *= tp_size # all-gathered sequence length - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - + inputmat = cast_if_needed(inp, activation_dtype) + if 
with_input_all_gather_nccl: + inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + inputmat_total = inputmat - _ = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - out=out, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - ) + # Cast weight to expected dtype + weightmat = weight + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) + else: + if not isinstance(weight, QuantizedTensor): + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) + + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(inputmat_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + out_dtype = 
activation_dtype + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG + if fp8: + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." + ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True) + inputmat_total = ub_obj.get_buffer(input_quantizer) + + out, *_, rs_out = general_gemm( + weightmat, + inputmat_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=out_dtype, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, + ) if is_grad_enabled: saved_inputmat = None - saved_inputmat_t = None - if weight.requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t is None: - saved_inputmat = inputmat - else: - saved_inputmat_t = inputmat_t - if cpu_offloading: - saved_inputmat_t.activation_offloading = True - else: - saved_inputmat = inputmat_no_fp8 + if backward_needs_input: + if own_quantized_input and isinstance(inputmat, QuantizedTensor): + inputmat.update_usage(rowwise_usage=False) + saved_inputmat = inputmat - if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - weight.weight_offloading = True - - if saved_inputmat is not None: - saved_inputmat.activation_offloading = True + if cpu_offloading: + set_offloading_param(weight, "weight_offloading", True) + set_offloading_param(weightmat, "weight_offloading", True) + if saved_inputmat is not None: + set_offloading_param(saved_inputmat, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized 
with primary Fp8 weights ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, - saved_inputmat, # None if fp8 == False - saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled - weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + saved_inputmat, + weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None, ) - ctx.save_for_backward( + # TODO(ksivamani): Check memory usage + tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, - saved_inputmat_t, - inputmat_scale_inv, + weightmat, weight, - weight_fp8, - weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, + bias, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta + ctx.input_quantizer = input_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if fuse_wgrad_accumulation and weight.requires_grad: + ctx.main_grad = weight.main_grad + ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias + ctx.use_bias = bias is not None ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape @@ -388,8 +285,10 @@ def forward( ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad - ctx.is_input_fp8 = is_input_fp8 + ctx.requires_wgrad = weight.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False + ctx.owns_input = saved_inputmat is not inp + ctx.is_input_fp8 = not own_quantized_input if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -397,34 +296,53 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module 
# Row Parallel Linear - if parallel_mode == "row": - if ub_overlap_rs_fprop: - out = rs_out - elif sequence_parallel: + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: out, _ = allreduce(out, tp_group) - # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp_shape[1:-1], out_features) + out = out.view(-1, *inp_shape[1:-1], out_features) + return out @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if isinstance(grad_output, Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ] = grad_output._scale_inv with torch.cuda.nvtx.range("_Linear_backward"): - ( - inputmat, - inputmat_t, - inputmat_scale_inv, - weight, - weight_fp8, - main_grad, - ) = ctx.saved_tensors + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + saved_tensors = ctx.saved_tensors + inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + restore_from_saved(ctx.tensor_objects, saved_tensors) + ) + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + weight = torch.nn.Parameter(weight, weight.requires_grad) + weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -433,105 +351,89 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.fsdp_group, ctx.fsdp_shapes, inputmat, - inputmat_t, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + weight_fp8, ) - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) - weight.main_grad = main_grad - - tp_world_size = get_distributed_world_size(ctx.tp_group) - ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad - ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad - ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad - ctx.ub_obj_gradout = None + ub_obj_dgrad = None ub_obj_wgrad = None - ub_algo_wgrad = None - ub_algo_dgrad = None - rs_out = None - dgrad = None + ub_type_dgrad = None + ub_type_wgrad = None dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - dgrad = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device - ) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - dgrad = ctx.ub_obj_gradout.get_ubuf_output(1) - if ctx.ub_obj_gradout.is_p2p_overlap(): - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - 
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS rs_out = torch.empty( dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device ) - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - inputmat_data = ( - inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat - ) - ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True) - inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1) - if isinstance(inputmat, Float8Tensor): - inputmat._data = inputmat_ubuf - else: - inputmat = inputmat_ubuf + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True) if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") - dgrad = ub_obj_wgrad.get_ubuf_output(1) - + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + if ctx.grad_output_quantizer is not None: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) ( grad_output, - grad_output_c, - grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_output, ctx.parallel_mode == "row" + ctx, 
+ grad_output, + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, ) - # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers) + # Prepare input tensor + # Note: Perform tensor-parallel communication if needed inputmat_total = None - inputmat_t_total = None - inputmat_gather_handle = None + inputmat_total_work = None if ( - weight.requires_grad + ctx.requires_wgrad and ctx.parallel_mode == "column" and ctx.sequence_parallel and not ctx.ub_bulk_dgrad ): - inputmat_total, inputmat_gather_handle = gather_along_first_dim( - inputmat, ctx.tp_group, async_op=ctx.requires_dgrad + quantizer = None + if ctx.fp8: + quantizer = ctx.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + inputmat_total, inputmat_total_work = gather_along_first_dim( + inputmat, + ctx.tp_group, + async_op=True, + quantizer=quantizer, ) else: inputmat_total = inputmat - inputmat_t_total = inputmat_t + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch @@ -539,185 +441,132 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - output_dtype = ctx.activation_dtype + # Compute grad input tensor + dgrad = None + dgrad_work = None if ctx.requires_dgrad: - if ctx.fp8: - if ctx.is_input_fp8 or ( - ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf() - ): - out_index, meta_tensor, output_te_dtype, output_dtype = ( - tex.FP8BwdTensors.GRAD_INPUT1, - ctx.fp8_meta["scaling_bwd"], - fp8_dtype_backward, - torch.uint8, + + # Update quantizer + if ctx.grad_input_quantizer is not None: + 
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + # dgrad GEMM + dgrad, *_, rs_out = general_gemm( + weight_fp8, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, + out_dtype=ctx.activation_dtype, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + + # Launch tensor-parallel communication + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, ) - if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf(): - ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, - None, - None, - ctx.activation_dtype, - ) + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) - if dgrad is None: - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - dgrad_shape[0] = dgrad_shape[0] * tp_world_size - dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device) - - if ctx.requires_dgrad: - if ctx.fp8: - _ = fp8_gemm( - weight_fp8.transpose_2d(), - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - output_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo_dgrad, - ub=ctx.ub_obj_gradout, - out=dgrad, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=output_te_dtype, - extra_output_tensor=rs_out, - ) + # Compute grad weight tensor + wgrad = None + if ctx.requires_wgrad: + if ctx.ub_bulk_dgrad: + inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) + if ctx.fp8: + if inputmat._data is None: + # All-gather executed on 
columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + inputmat_total = _fix_gathered_fp8_transpose( + inputmat_total, ctx.tp_size + ) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + inputmat_total._create_transpose() - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out - elif output_dtype == torch.uint8: - dgrad = Float8Tensor( - data=dgrad, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1, - fp8_dtype=fp8_dtype_backward, - dtype=ctx.activation_dtype, - ) else: - _ = gemm( - weight, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NN", - grad=True, - ub_algo=ub_algo_dgrad, - ub=ctx.ub_obj_gradout, - out=dgrad, - extra_output_tensor=rs_out, + if inputmat_total_work is not None: + # Synchronize tensor-parallel communication + inputmat_total_work.wait() + inputmat_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. 
+ grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device ) - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out - - if inputmat_gather_handle is not None: - inputmat_gather_handle.wait() - - # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers) - dgrad_reduce_handle = None - if ctx.requires_dgrad and ctx.parallel_mode == "column": - if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad): - dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True - ) - elif ctx.tensor_parallel and not ctx.sequence_parallel: - dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True) + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + wgrad, grad_bias_, _, rs_out = general_gemm( + inputmat_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + bias=(bias if (grad_bias is None and not ctx.fp8) else None), + out=main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) - wgrad = None - if weight.requires_grad: - if ctx.fp8: - # WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if ctx.ub_overlap_ag: - if isinstance(grad_output_c, Float8Tensor): - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - if inputmat_t_total is None: - if isinstance(inputmat_total, Float8Tensor): - inputmat_t_total = inputmat_total.transpose_2d() - else: - inputmat_t_total = tex.fp8_transpose( - inputmat_total, fp8_dtype_backward - ) - wgrad, 
_ = fp8_gemm( - ( - inputmat_t_total._data - if isinstance(inputmat_t_total, Float8Tensor) - else inputmat_t_total - ), - inputmat_scale_inv, - 0, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub=ub_obj_wgrad, - ub_algo=ub_algo_wgrad, - ) + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out else: - wgrad, _, _ = gemm( - inputmat_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub=ub_obj_wgrad, - ub_algo=ub_algo_wgrad, - ) - else: - # WGRAD - wgrad, grad_bias, _ = gemm( - inputmat_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub=ub_obj_wgrad, - ub_algo=ub_algo_wgrad, - ) + dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) - if ctx.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_ubuf_output(0) + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ # Deallocate input tensor - clear_tensor_data(inputmat_total) - clear_tensor_data(inputmat_t_total) - - # Wait for dgrad reduce-scatter or all-reduce - if dgrad_reduce_handle is not None: - dgrad_reduce_handle.wait() + if ctx.owns_input: + clear_tensor_data(inputmat_total) + # Don't return grad bias if not needed if not ctx.use_bias: grad_bias = None - if weight.requires_grad: + # Synchronize tensor parallel communication + if inputmat_total_work is not None: + inputmat_total_work.wait() + inputmat_total_work = None + if dgrad_work is not 
None: + dgrad_work.wait() + dgrad_work = None + + if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): + if ( + ctx.fuse_wgrad_accumulation + and weight is not None + and hasattr(weight, "grad_added_to_main_grad") + ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): wgrad = torch.zeros( @@ -727,12 +576,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], requires_grad=False, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + wgrad = None elif ctx.fuse_wgrad_accumulation: wgrad = None else: @@ -742,19 +586,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, Float8Tensor): + if ctx.fp8 and not isinstance(weight, QuantizedTensor): _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) - return ( wgrad, - None, # weight_fp8 dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, grad_bias, - None, # use_bias None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -773,6 +618,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_name None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -865,6 +712,7 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, @@ 
-878,6 +726,8 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias + self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -903,17 +753,32 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel # Column parallel TP overlap options - self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag - self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs - self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad - self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad - if self.ub_overlap_rs_dgrad: - self.ub_bulk_dgrad = False - self.ub_bulk_wgrad = False + self.ub_overlap_ag_fprop = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag + ) + self.ub_overlap_rs_dgrad = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_dgrad + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_wgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_wgrad + and not self.ub_overlap_rs_dgrad + ) # Row parallel TP overlap options - self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs - self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag + self.ub_overlap_rs_fprop = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs + ) + self.ub_overlap_ag_dgrad = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag + ) if any( [ @@ -928,19 +793,6 @@ def __init__( assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is 
not initialized." self.ub_name = ub_name - assert not ( - self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop - ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." - assert not ( - self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad - ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." - assert not ( - self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) - ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time." - - self.get_rng_state_tracker = get_rng_state_tracker - self.rng_tracker_name = rng_tracker_name - # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1017,7 +869,9 @@ def __init__( # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) # Construct weight parameter self.register_parameter( @@ -1084,6 +938,7 @@ def forward( inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. 
@@ -1115,7 +970,6 @@ def forward( with self.prepare_forward( inp, - is_first_microbatch, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: @@ -1129,36 +983,25 @@ def forward( ) else: unfused_weights = [w.dequantize() for w in unfused_weights] - weight_tensor = _noop_cat(unfused_weights) + weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused - - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=self.fsdp_group, - ) + bias_tensor = None + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output, fp8_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor): + weight_tensor._quantizer = weight_quantizer if torch.is_grad_enabled(): linear_fn = _Linear.apply @@ -1168,14 +1011,16 @@ def forward( args = [None] args += ( weight_tensor, - weight_fp8, inp, - 
bias_tensor, - self.apply_bias and not self.gemm_bias_unfused_add, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1194,12 +1039,38 @@ def forward( self.ub_name, fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = linear_fn(*args) - if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) if self.return_bias: return out, cast_if_needed(bias_tensor, self.activation_dtype) return out + + def _get_quantizers(self, fp8_output, fp8_grad): + if not self.fp8: + return [None] * 5 + grad_input_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = False + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 26bceab737..bb826e552e 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -11,11 +11,11 @@ from transformer_engine_torch import FP8TensorMeta from ..fp8 import FP8GlobalStateManager -from ..tensor import 
Float8Tensor +from ..tensor.float8_tensor import Float8Tensor from ..utils import ( - canonicalize_device, # pylint: disable=unused-import - canonicalize_dtype, # pylint: disable=unused-import - devices_match, # pylint: disable=unused-import + canonicalize_device, + canonicalize_dtype, + devices_match, ) @@ -61,12 +61,9 @@ def convert_tensor( # Note: torch.Tensor.to ignores memory_format kwarg (see # https://github.com/pytorch/pytorch/issues/132020). data = data.contiguous(memory_format=memory_format) - return Float8Tensor.make_like( - tensor, - data=data, - fp8_attrs=tensor._fp8_attrs, - dtype=dtype, - ) + out = Float8Tensor.make_like(tensor, dtype=dtype) + out.data = data + return out # Convert standard PyTorch tensor tensor = tensor.to(device=device, dtype=dtype) @@ -85,46 +82,14 @@ def reshape( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor | Float8Tensor: - """Reshape tensor, keeping same data if possible - - If the input is a Float8Tensor, this function attempts to preserve - the cached transpose if available and valid. If a cached transpose - is present, it is interpreted as the transpose of a 2D matrix - where the width matches the innermost tensor dimension. 
- - """ - - # Make sure tensor is in expected format + """Reshape tensor, keeping same data if possible""" tensor = convert_tensor( tensor, device=device, dtype=dtype, memory_format=torch.contiguous_format, ) - - # Return immediately if tensor already has desired shape - shape = list(shape) - if len(shape) == tensor.dim(): - if sum(1 for d in shape if d == -1) > 1: - raise ValueError( - "Attempted to reshape tensor with " - f"shape={tuple(tensor.size())} into shape={tuple(shape)}" - ) - if all(d1 == d2 for d1, d2 in zip(shape, tensor.size()) if d1 != -1): - return tensor - - # Reshape FP8 tensor - # Note: Preserve cached transpose if possible - if is_float8_tensor(tensor): - out = Float8Tensor.make_like( - tensor, - data=tensor._data.view(shape), - fp8_attrs=tensor._fp8_attrs, - ) - return out - - # Reshape standard PyTorch tensor - return tensor.view(shape) + return tensor.reshape(*shape) def maybe_autocast_dtype( diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 7ad6e70929..45c78bea87 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -10,20 +10,12 @@ import torch -import transformer_engine_torch -from ...constants import TE_DType -from ...cpp_extensions import ( - geglu as tex_geglu, - gelu as tex_gelu, - reglu as tex_reglu, - relu as tex_relu, - swiglu as tex_swiglu, - fp8_dswiglu_cast_transpose_fused, -) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +import transformer_engine_torch as tex +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor from ...utils import clear_tensor_data, devices_match from ..op import BasicOperation, OperationContext +from .._common import reshape class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): @@ -93,43 +85,23 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = 
FP8GlobalStateManager.is_fp8_enabled() - with_fp8_output = False - output_fp8_meta = None - output_dtype = TE_DType[dtype] - output_fp8_scale_inv = None - if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: - with_fp8_output = True - fp8_meta = next_op.get_fp8_meta("input") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - output_fp8_meta = fp8_meta[fp8_meta_key] - output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0: + quantizer = next_op.get_quantizer("forward", 0) + else: + quantizer = None # Launch kernel y = self._activation_forward_impl( - x, - output_fp8_meta, - 0, - output_dtype, - scale_inv=output_fp8_scale_inv, + reshape(x, (-1, x.size(-1))), + quantizer, ) # Check output tensor if y.dim() != x.dim(): y = y.reshape(list(x.shape[:-1]) + [-1]) - if with_fp8_output: - y = Float8Tensor( - data=y, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=output_dtype, - fp8_scale_inv=output_fp8_scale_inv, - dtype=dtype, - ) # Save state for backward pass - ctx.save_for_backward(x) + ctx.save_for_backward(x.detach()) ctx.fp8_enabled = fp8_enabled ctx.prev_op = prev_op @@ -154,7 +126,11 @@ def op_backward( dy = dy.contiguous() # Launch kernel - dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) + dx = self._activation_backward_impl( + reshape(dy, (-1, dy.size(-1))), + reshape(x, (-1, x.size(-1))), + None, + ) # Check grad input tensor if dx.size() != x.size(): @@ -181,10 +157,10 @@ class GELU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_gelu(*args, **kwargs) + return tex.gelu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dgelu(*args, **kwargs) + return 
tex.dgelu(*args, **kwargs) class ReLU(_ActivationOperation): @@ -197,10 +173,10 @@ class ReLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_relu(*args, **kwargs) + return tex.relu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.drelu(*args, **kwargs) + return tex.drelu(*args, **kwargs) class GEGLU(_ActivationOperation): @@ -232,10 +208,10 @@ class GEGLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_geglu(*args, **kwargs) + return tex.geglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dgeglu(*args, **kwargs) + return tex.dgeglu(*args, **kwargs) class ReGLU(_ActivationOperation): @@ -261,10 +237,10 @@ class ReGLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_reglu(*args, **kwargs) + return tex.reglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dreglu(*args, **kwargs) + return tex.dreglu(*args, **kwargs) class SwiGLU(_ActivationOperation): @@ -299,92 +275,7 @@ class SwiGLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_swiglu(*args, **kwargs) + return tex.swiglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dswiglu(*args, **kwargs) - - def op_backward( - self, - ctx: OperationContext, - grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, tuple[()]]: - - # Saved tensors from forward pass - (x,) = ctx.saved_tensors - - # Tensor attributes - dtype = x.dtype - device = x.device - - # Check grad output tensor - dy = grad_output - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() - if not devices_match(dy.device, device) 
or dy.dtype != dtype: - dy = dy.to(device=device, dtype=dtype) - if not dy.is_contiguous(): - dy = dy.contiguous() - - # Check if FP8 is enabled - with_fp8_grad_input = False - grad_input_fp8_meta = None - grad_input_dtype = TE_DType[dtype] - grad_input_fp8_scale_inv = None - if ( - ctx.fp8_enabled - and ctx.prev_op is not None - and ctx.prev_op.num_fp8_scales("grad_output") > 0 - ): - with_fp8_grad_input = True - fp8_meta = ctx.prev_op.get_fp8_meta("grad_output") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - grad_input_fp8_meta = fp8_meta[fp8_meta_key] - grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) - - # Launch kernel - if with_fp8_grad_input: - # Fused with FP8 cast-transpose - input_dims = x.size() - flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]] - flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2] - dx = torch.empty(input_dims, dtype=torch.uint8, device=device) - dx_t = torch.empty( - (flat_input_dims[1], flat_input_dims[0]), - dtype=torch.uint8, - device=device, - ) - fp8_dswiglu_cast_transpose_fused( - dy.reshape(flat_output_dims), - x.reshape(flat_input_dims), - grad_input=dx.reshape(flat_input_dims), - grad_input_transpose=dx_t, - otype=grad_input_dtype, - fp8_meta=grad_input_fp8_meta, - fp8_meta_index=0, - scale_inv=grad_input_fp8_scale_inv, - ) - dx = Float8Tensor( - data=dx, - fp8_meta=grad_input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=grad_input_dtype, - fp8_scale_inv=grad_input_fp8_scale_inv, - dtype=dtype, - ) - dx._transpose = dx_t - dx._transpose_invalid = False - else: - # Standard impl - dx = self._activation_backward_impl(dy, x, TE_DType[dtype]) - if dx.size() != x.size(): - dx = dx.reshape(x.size()) - - # Note: This fails if op is preceeded by an identity op like Quantize(forward=False) - # # Clear input tensor if possible - # if ctx.prev_op is not 
None: - # clear_tensor_data(x) - - return dx, () + return tex.dswiglu(*args, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index 2dd1d1b75e..15b1f65d85 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -9,12 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import convert_tensor, is_float8_tensor +from ...distributed import gather_along_first_dim +from ...tensor import QuantizedTensor +from ..op import BasicOperation, OperationContext class AllGather(BasicOperation): @@ -45,47 +42,12 @@ def op_forward( prev_op: Optional[BasicOperation] = None, next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: - - # Trivial case + out: torch.Tensor if self.process_group_size == 1: - return input_ - - # Tensor dimensions - input_dims = input_.size() - if not input_dims: - raise RuntimeError( - "Attempted to all-gather a tensor " - f"with shape={list(input_dims)} " - f"over {self.process_group_size} processes" - ) - output_dims = list(input_dims) - output_dims[0] *= self.process_group_size - - # Perform all-gather - x = convert_tensor(input_, memory_format=torch.contiguous_format) - y = None - if is_float8_tensor(x): - y = Float8Tensor.make_like( - x, - data=torch.empty( - output_dims, - dtype=torch.uint8, - device=x.device, - ), - ) - torch.distributed.all_gather_into_tensor( - y._data, - x._data, - group=self.process_group, - ) + out = input_.detach() else: - y = torch.empty(output_dims, dtype=x.dtype, device=x.device) - torch.distributed.all_gather_into_tensor( - y, - x, - group=self.process_group, - ) - return y + out, _ = gather_along_first_dim(input_, self.process_group) + return out def op_backward( self, @@ -110,8 +72,8 @@ def op_backward( # Check output gradient tensor dy 
= grad_output - if is_float8_tensor(dy): - dy = dy.from_float8() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() dy = dy.contiguous() # Perform reduce-scatter diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index c5178d2d91..892e120da1 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,33 +12,24 @@ import torch -from transformer_engine.pytorch.cpp_extensions import ( - FP8TensorMeta, - fp8_gemm, - gemm, -) -from transformer_engine.pytorch.distributed import ( +from transformer_engine.pytorch.module.base import get_workspace +from ...cpp_extensions import general_gemm +from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - get_fp8_te_dtype, -) -from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) +from ...fp8 import FP8GlobalStateManager +from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD +from ...tensor import Quantizer, QuantizedTensor +from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ..op import BasicOperation, OperationContext from .._common import ( canonicalize_device, canonicalize_dtype, - convert_tensor, devices_match, - is_float8_tensor, - reshape, ) from ...utils import clear_tensor_data @@ -110,17 +101,8 @@ def __init__( self.in_features: int = in_features self.out_features: int = out_features - # Weight tensor device - defer_param_init = False + # Weight tensor attributes device = canonicalize_device(device) - if device.type == "meta": - defer_param_init 
= True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device - - # Weight tensor datatype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") @@ -147,16 +129,14 @@ def __init__( out_features=out_features, ) - # Whether weight tensor is natively in FP8 - self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters() - if self._with_fp8_parameters: - self._fp8_metas = self._make_fp8_metas() + # Whether weight tensor is natively quantized + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() # Initialize parameters if needed weight = torch.empty( self.local_out_features, self.local_in_features, - device="meta", + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -164,7 +144,7 @@ def __init__( self.register_parameter("weight", weight) self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] self._rng_state_tracker_function = rng_state_tracker_function - if not defer_param_init: + if weight.device.type != "meta": self.reset_parameters() # Whether to accumulate weight gradient into main_grad @@ -273,43 +253,48 @@ def _canonicalize_tensor_parallelism( local_out_features, ) - def num_fp8_scales(self, mode: str) -> int: - if mode in ("input", "param", "grad_output"): + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 + if mode == "backward": return 1 return 0 def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda" or is_float8_tensor(weight): - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if 
device.type == "meta": + device = canonicalize_device(None) + + # Allocate buffer if needed + if isinstance(weight, QuantizedTensor): + weight = torch.empty( + weight.size(), + dtype=weight.dtype, + device=device, + ) + elif not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values - init_context = contextlib.nullcontext + init_context = contextlib.nullcontext() if self._rng_state_tracker_function is not None: - init_context = self._rng_state_tracker_function().fork - with init_context(): + init_context = self._rng_state_tracker_function().fork() + with init_context: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - # Cast to FP8 if needed - if self._with_fp8_parameters: - dummy_amax = torch.empty( - (1, 1), - dtype=torch.float32, - device=self.device, - ) # Dummy buffer to avoid overwriting amax history - weight = Float8Tensor.to_float8( - weight, - fp8_meta=self.get_fp8_meta("param"), - fp8_meta_forward=True, - fp8_meta_index=0, - amax=dummy_amax, - with_transpose_cache=torch.is_grad_enabled(), + # Quantize if needed + if self._with_quantized_weight: + quantizer = self.get_quantizer("forward", 1) + quantizer.set_usage( + rowwise=True, + columnwise=torch.is_grad_enabled(), ) + with torch.no_grad(): + weight = quantizer(weight) # Save updated parameter if not isinstance(weight, torch.nn.Parameter): @@ -318,8 +303,33 @@ def reset_parameters(self) -> None: def pre_forward(self, *args, **kwargs) -> None: super().pre_forward(*args, **kwargs) - if self.weight.device.type == "meta": + + # Initialize weights if needed + weight = self.weight + if weight.device.type == "meta": self.reset_parameters() + weight = self.weight + + # Configure quantizers + if FP8GlobalStateManager.is_fp8_enabled(): + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + + # Specify required tensor formats + is_grad_enabled 
= torch.is_grad_enabled() + weight_requires_grad = is_grad_enabled and weight.requires_grad + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if isinstance(weight_quantizer, Float8Quantizer) and isinstance( + weight, Float8TensorBase + ): + weight._quantizer = weight_quantizer @staticmethod def _functional_forward( @@ -327,17 +337,17 @@ def _functional_forward( weight: torch.Tensor, *, bias: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, + device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, out: Optional[torch.Tensor] = None, accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, - with_fp8_compute: bool = False, - input_fp8_meta: Optional[dict[str, Any]] = None, - weight_fp8_meta: Optional[dict[str, Any]] = None, - output_fp8_meta: Optional[dict[str, Any]] = None, + with_quantized_compute: bool = False, + input_quantizer: Optional[Quantizer] = None, + weight_quantizer: Optional[Quantizer] = None, + output_quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Functional API for forward pass @@ -366,16 +376,14 @@ def _functional_forward( parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_fp8_compute: bool, default = `False` - Whether to perform compute in FP8 - input_fp8_meta: dict, optional - FP8 metadata for casting input tensor to FP8. Required for - FP8 compute if input is not already in FP8. 
- weight_fp8_meta: dict, optional - FP8 metadata for casting weight tensor to FP8. Required for - FP8 compute if weight is not already in FP8. - output_fp8_meta: dict, optional - FP8 metadata for casting output tensor to FP8 + with_quantized_compute: bool, default = `False` + Whether to perform compute with quantized data. + input_quantizer: Quantizer, optional + Builder class for quantized input tensor. + weight_quantizer: Quantizer, optional + Builder class for quantized weight tensor. + output_quantizer: Quantizer, optional + Builder class for quantized output tensor. Returns ------- @@ -390,17 +398,6 @@ def _functional_forward( """ - # Check device - if device is None: - device = weight.device if out is None else out.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - if out is not None and not devices_match(out.device, device): - raise ValueError( - f"Output tensor has invalid device (expected {device}, got {out.device})" - ) - # Check datatype if dtype is None: dtype = weight.dtype if out is None else out.dtype @@ -410,36 +407,88 @@ def _functional_forward( if out is not None and out.dtype != dtype: raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})") - # Check input tensor dims - input_dims = tuple(input.size()) - weight_dims = tuple(weight.size()) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - - # Check output tensor dims - output_dims: list[int] - if out is None: - output_dims = list(input_dims) - output_dims[0] = -1 - output_dims[-1] = weight_dims[0] + # Check input tensor + x_local = input + x = None + x_async = None + with_x_all_gather = tensor_parallel_mode == "column" and 
sequence_parallel + own_quantized_x_local = False + if with_quantized_compute: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(rowwise=True) + if with_x_all_gather: + input_quantizer.set_usage(columnwise=False) + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + quantizer=input_quantizer, + ) + else: + if not isinstance(x_local, QuantizedTensor): + x_local = input_quantizer(x_local) + own_quantized_x_local = True + x = x_local else: - output_dims = list(out.size()) - if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]: + if isinstance(x_local, QuantizedTensor): + x_local = x_local.dequantize() + if x_local.dtype != dtype: + x_local = x_local.to(dtype=dtype) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + else: + x = x_local + + # Check weight tensor + w = weight + w_is_quantized = isinstance(w, QuantizedTensor) + if with_quantized_compute and not w_is_quantized: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(rowwise=True) + w = weight_quantizer(w) + elif not with_quantized_compute and w_is_quantized: + w = w.dequantize() + if not with_quantized_compute and w.dtype != dtype: + w = w.to(dtype=dtype) + + # Check output tensor + y = out + if y is None: + if not with_quantized_compute: + output_quantizer = None + if tensor_parallel_mode == "row": + output_quantizer = None + elif isinstance(y, QuantizedTensor): + if not with_quantized_compute: + raise ValueError("Output tensor is quantized, but quantized compute is not enabled") + if tensor_parallel_mode == "row": raise ValueError( - f"Output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" + "Output tensor is quantized, " + "but row tensor parallelism does not support quantized output" ) + if output_quantizer 
is None: + output_quantizer = getattr(y, "_quantizer", None) + if output_quantizer is None: + raise ValueError("Output tensor is quantized, but quantizer was not provided") + else: + output_quantizer = None + if isinstance(output_quantizer, MXFP8Quantizer): + raise RuntimeError( + "Attempting to generate MXFP8 output tensor, " + "but GEMM with MXFP8 output is not supported" + ) + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) # Check if accumulating into output tensor if accumulate_into_out: - if out is None: + if y is None: raise ValueError( "Attempted to accumulate into output tensor without providing output tensor" ) @@ -448,181 +497,22 @@ def _functional_forward( "Accumulating into output tensor is not supported with row tensor parallelism" ) - # Check if FP8 is enabled - if with_fp8_compute: - if input_fp8_meta is None and not is_float8_tensor(input): - raise ValueError("No FP8 metadata was provided for casting input to FP8") - if weight_fp8_meta is None and not is_float8_tensor(weight): - raise ValueError("No FP8 metadata was provided for casting weight to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" - if out is None: - with_fp8_output = with_fp8_output and output_fp8_meta is not None - else: - if is_float8_tensor(out): - if not with_fp8_output: - raise ValueError( - "Output tensor is a Float8Tensor, but FP8 output is not supported" - ) - out._reset_caches() - else: - with_fp8_output = False - - # Check input tensor - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - with_transpose_cache = weight.requires_grad - if tensor_parallel_mode == "column" and sequence_parallel: - with_transpose_cache = False - x_local = 
Float8Tensor.to_float8( - x_local, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=with_transpose_cache, - ) - elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.dequantize() - x = x_local + # Synchronize communication for input + _wait_async(x_async) x_async = None - if tensor_parallel_mode == "column" and sequence_parallel: - x, x_async = gather_along_first_dim( - x_local, - tensor_parallel_group, - async_op=True, - ) - - # Check weight tensor - w = convert_tensor( - weight, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if with_fp8_compute and not is_float8_tensor(w): - fp8_dtype = get_fp8_te_dtype( - weight_fp8_meta["recipe"], - fprop_tensor=True, - ) - w = Float8Tensor.to_float8( - w, - fp8_meta=weight_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) - elif not with_fp8_compute and is_float8_tensor(w): - w = w.dequantize() - - # Check bias tensor - b = None - if bias is not None: - b = convert_tensor( - bias, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - - # Construct output tensor - y = None - if out is not None: - y = reshape(out, (-1, output_dims[-1])) - elif with_fp8_output: - fp8_dtype = get_fp8_te_dtype( - output_fp8_meta["recipe"], - fprop_tensor=True, - ) - data = torch.empty( - (x.size(0), weight_dims[0]), - dtype=torch.uint8, - device=device, - ) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - y = torch.empty( - (x.size(0), weight_dims[0]), - dtype=dtype, - device=device, - ) # Perform GEMM - _wait_async(x_async) - x_async = None - if with_fp8_compute: - kwargs = { - "accumulate": accumulate_into_out, - "out": y, - "bias": b, - "use_bias": (b is not None), - } - if with_fp8_output: - if y._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - 
fp8_meta = FP8TensorMeta() - fp8_meta.scale = y._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) - fp8_meta.scale_inv = y._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=y._fp8_meta_forward, - ) - fp8_meta = y._fp8_meta[fp8_meta_key] - fp8_meta_index = y._fp8_meta_index - kwargs.update( - { - "out": y._data, - "out_index": fp8_meta_index, - "fp8_meta_tensor": fp8_meta, - "D_dtype": y._fp8_dtype, - } - ) - fp8_gemm( - w._data, - w._scale_inv, - 0, - w._fp8_dtype, - x._data, - x._scale_inv, - 0, - x._fp8_dtype, - y.dtype, - get_workspace(), - **kwargs, - ) - else: - gemm( - w, - x, - y.dtype, - get_workspace(), - accumulate=accumulate_into_out, - out=y, - bias=b, - use_bias=(b is not None), - ) + y, *_ = general_gemm( + w, + x, + get_workspace(), + out_dtype=dtype, + quantization_params=output_quantizer, + accumulate=accumulate_into_out, + out=y, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ) # Reduce tensor-parallel output if needed if tensor_parallel_mode == "row": @@ -631,23 +521,29 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Reshape output tensor if needed - if out is None: - out = reshape(y, output_dims) + # Configure input tensor for backward pass + if own_quantized_x_local: + ### TODO Restore once column-wise usage is supported by itself # pylint: disable=fixme + # x_local.update_usage(rowwise_usage=False) + pass + + # Detach input tensor if needed + # Note: PyTorch autograd produces esoteric errors if we save + # input tensor as context for backward pass. 
+ if x_local is input: + x_local = x_local.detach() - return out, x_local, w + return y, x_local, w @staticmethod def _functional_backward( grad_output: torch.Tensor, input: Optional[torch.Tensor], # pylint: disable=redefined-builtin weight: Optional[torch.Tensor], - input_dims: Iterable[int], - weight_dims: Iterable[int], *, input_requires_grad: bool = True, weight_requires_grad: bool = True, - device: Optional[torch.device] = None, + device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, grad_weight: Optional[torch.Tensor] = None, accumulate_into_grad_weight: bool = False, @@ -656,11 +552,11 @@ def _functional_backward( tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, - with_fp8_compute: bool = False, - input_fp8_meta: Optional[dict[str, Any]] = None, - weight_fp8_meta: Optional[dict[str, Any]] = None, - grad_output_fp8_meta: Optional[dict[str, Any]] = None, - grad_input_fp8_meta: Optional[dict[str, Any]] = None, + with_quantized_compute: bool = False, + input_quantizer: Optional[Quantizer] = None, + weight_quantizer: Optional[Quantizer] = None, + grad_output_quantizer: Optional[Quantizer] = None, + grad_input_quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Functional API for backward pass @@ -674,10 +570,6 @@ def _functional_backward( weight: torch.Tensor, optional Weight tensor. Required to compute loss gradient w.r.t. input. - input_dims: iterable of int - Input tensor dimensions - weight_dims: iterable of int - Weight tensor dimensions input_requires_grad: bool Whether to compute loss gradient w.r.t. input tensor weight_requires_grad: bool @@ -703,21 +595,18 @@ def _functional_backward( parallelism, i.e. 
distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_fp8_compute: bool, default = `False` - Whether to perform compute in FP8 - input_fp8_meta: dict, optional - FP8 metadata for casting input tensor to FP8. Required for - FP8 compute if input is not already in FP8. - weight_fp8_meta: dict, optional - FP8 metadata for casting weight tensor to FP8. Required for - FP8 compute if weight is not already in FP8. - grad_output_fp8_meta: dict, optional - FP8 metadata for casting loss gradient w.r.t. output - tensor to FP8. Required if output grad is not already in - FP8. - grad_input_fp8_meta: dict, optional - FP8 metadata for casting loss gradient w.r.t. input - tensor to FP8 + with_quantized_compute: bool, default = `False` + Whether to perform compute with quantized data. + input_quantizer: Quantizer, optional + Builder class for quantized input tensor. + weight_quantizer: Quantizer, optional + Builder class for quantized weight tensor. + grad_output_quantizer: Quantizer, optional + Builder class for quantized loss gradient w.r.t. output + tensor. + grad_input_quantizer: dict, optional + Builder class for quantized loss gradient w.r.t. input + tensor. 
Returns ------- @@ -728,13 +617,6 @@ def _functional_backward( """ - # Check device - if device is None: - device = weight.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - # Check datatype if dtype is None: dtype = weight.dtype @@ -742,109 +624,42 @@ def _functional_backward( if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") - # Check tensor dims - output_dims = tuple(grad_output.size()) - input_dims = tuple(input_dims) - weight_dims = tuple(weight_dims) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if weight_dims[0] != output_dims[-1]: - raise ValueError( - f"Grad output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if grad_input is not None and tuple(grad_input.size()) != input_dims: - raise ValueError( - f"Grad input tensor (shape={tuple(grad_input.size())}) " - f"does not match expected shape ({input_dims})" - ) - - # Check grad input tensor - if not input_requires_grad: - grad_input = None - if grad_input is not None and not devices_match(grad_input.device, device): - raise ValueError( - f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})" - ) - if grad_input is not None and grad_input.dtype != dtype: - raise ValueError( - f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})" + # Check grad output tensor + dy_local = grad_output + dy = None + dy_async = None + with_dy_all_gather = tensor_parallel_mode == "row" and sequence_parallel + if with_quantized_compute: + if grad_output_quantizer is None: + raise 
ValueError("Missing quantizer for grad output tensor") + grad_output_quantizer.set_usage( + rowwise=input_requires_grad, + columnwise=weight_requires_grad, ) - if accumulate_into_grad_input: - if grad_input is None: - raise ValueError( - "Attempted to accumulate into grad input tensor " - "without providing grad input tensor" - ) - if tensor_parallel_mode == "column": - raise ValueError( - "Accumulating into grad input tensor " - "is not supported with column tensor parallelism" + if with_dy_all_gather: + dy, dy_async = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + quantizer=grad_output_quantizer, ) - - # Check if FP8 is enabled - if with_fp8_compute: - if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): - raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - with_fp8_grad_input = ( - with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column" - ) - if grad_input is None: - with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None + else: + if not isinstance(dy_local, QuantizedTensor): + dy_local = grad_output_quantizer(dy_local) + dy = dy_local else: - if is_float8_tensor(grad_input): - if not with_fp8_grad_input: - raise ValueError( - "Grad input tensor is a Float8Tensor, but FP8 output is not supported" - ) - grad_input._reset_caches() + if isinstance(dy_local, QuantizedTensor): + dy_local = dy_local.dequantize() + if dy_local.dtype != dtype: + dy_local = dy_local.to(dtype=dtype) + if with_dy_all_gather: + dy, dy_async = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + ) else: - with_fp8_grad_input = False - - # Check grad output tensor - dy_async = None - dy = reshape( - grad_output, - (-1, output_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(dy): - 
fp8_dtype = get_fp8_te_dtype( - grad_output_fp8_meta["recipe"], - fprop_tensor=False, - ) - with_transpose_cache = weight_requires_grad - if tensor_parallel_mode == "row" and sequence_parallel: - with_transpose_cache = False - dy = Float8Tensor.to_float8( - dy, - fp8_meta=grad_output_fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=with_transpose_cache, - ) - elif not with_fp8_compute and is_float8_tensor(dy): - dy = dy.dequantize() - if tensor_parallel_mode == "row" and sequence_parallel: - dy, dy_async = gather_along_first_dim( - dy, - tensor_parallel_group, - async_op=True, - ) + dy = dy_local # Check input tensor x = None @@ -852,35 +667,36 @@ def _functional_backward( if weight_requires_grad: if input is None: raise ValueError("Input tensor is required to compute weight grad") - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - x_local = Float8Tensor.to_float8( - x_local, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=(not x_is_sharded), - ) - elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.from_float8() - x = x_local - if x_is_sharded: - x, x_async = gather_along_first_dim( - x_local, - tensor_parallel_group, - async_op=True, - ) + x_local = input + with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + if with_quantized_compute: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(columnwise=True) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + quantizer=input_quantizer, + ) + else: + if not 
isinstance(x_local, QuantizedTensor): + x_local = input_quantizer(x_local) + x = x_local + else: + if isinstance(x_local, QuantizedTensor): + x_local = x_local.dequantize() + if x_local.dtype != dtype: + x_local = x_local.to(dtype=dtype) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + else: + x = x_local # Compute grad input dx = None @@ -890,110 +706,80 @@ def _functional_backward( # Check weight tensor if weight is None: raise ValueError("Weight tensor is required to compute input grad") - w = convert_tensor( - weight, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if with_fp8_compute and not is_float8_tensor(w): - fp8_dtype = get_fp8_te_dtype( - weight_fp8_meta["recipe"], - fprop_tensor=True, - ) - w = Float8Tensor.to_float8( - w, - fp8_meta=weight_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=True, - ) - elif not with_fp8_compute and is_float8_tensor(w): + w = weight + w_is_quantized = isinstance(w, QuantizedTensor) + if with_quantized_compute and not w_is_quantized: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + elif not with_quantized_compute and w_is_quantized: w = w.dequantize() + if not with_quantized_compute and w.dtype != dtype: + w = w.to(dtype=dtype) - # Construct grad input tensor - if grad_input is not None: - dx = reshape(grad_input, (-1, input_dims[-1])) - elif with_fp8_grad_input: - fp8_dtype = get_fp8_te_dtype( - grad_input_fp8_meta["recipe"], - fprop_tensor=False, - ) - data = torch.empty( - (dy.size(0), weight_dims[1]), - dtype=torch.uint8, - device=device, - ) - dx = Float8Tensor( - data=data, - fp8_meta=grad_input_fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - dx = torch.empty( - (dy.size(0), weight_dims[1]), - 
dtype=dtype, - device=device, - ) - - # Perform dgrad GEMM + # Synchronize tensor-parallel communication _wait_async(dy_async) dy_async = None - if with_fp8_compute: - kwargs = {"accumulate": accumulate_into_grad_input, "out": dx} - if with_fp8_grad_input: - if dx._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = dx._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty( - 1, 1, dtype=torch.float32, device=device - ) - fp8_meta.scale_inv = dx._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dx._fp8_meta_forward, - ) - fp8_meta = dx._fp8_meta[fp8_meta_key] - fp8_meta_index = dx._fp8_meta_index - kwargs.update( - { - "out": dx._data, - "out_index": fp8_meta_index, - "fp8_meta_tensor": fp8_meta, - "D_dtype": dx._fp8_dtype, - } + + # Check grad input tensor + dx = grad_input + if dx is None: + if not with_quantized_compute: + grad_input_quantizer = None + if tensor_parallel_mode == "column": + grad_input_quantizer = None + elif isinstance(dx, QuantizedTensor): + if not with_quantized_compute: + raise ValueError( + "Grad input tensor is quantized, but quantized compute is not enabled" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Grad input tensor is quantized, " + "but column tensor parallelism does not support quantized grad input" + ) + if grad_input_quantizer is None: + grad_input_quantizer = getattr(dx, "_quantizer", None) + if grad_input_quantizer is None: + raise ValueError( + "Grad input tensor is quantized, but quantizer was not provided" ) - fp8_gemm( - w.transpose_2d(), - w._scale_inv, - 0, - w._fp8_dtype, - dy._data, - dy._scale_inv, - 0, - dy._fp8_dtype, - dx.dtype, - get_workspace(), - **kwargs, - ) else: - gemm( - w, - dy, - dx.dtype, - get_workspace(), - accumulate=accumulate_into_grad_input, - layout="NN", - out=dx, + grad_input_quantizer = None + if 
isinstance(grad_input_quantizer, MXFP8Quantizer): + raise RuntimeError( + "Attempting to generate MXFP8 grad input tensor, " + "but GEMM with MXFP8 output is not supported" ) + # Check if accumulating into grad input tensor + if accumulate_into_grad_input: + if dx is None: + raise ValueError( + "Attempted to accumulate into grad input tensor " + "without providing grad input tensor" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Accumulating into grad input tensor " + "is not supported with column tensor parallelism" + ) + + # Perform dgrad GEMM + dx, *_ = general_gemm( + w, + dy, + get_workspace(), + out_dtype=dtype, + quantization_params=grad_input_quantizer, + accumulate=accumulate_into_grad_input, + layout="NN", + out=dx, + use_split_accumulator=_2X_ACC_DGRAD, + grad=True, + ) + # Reduce tensor-parallel grad input if needed if tensor_parallel_mode == "column": if sequence_parallel: @@ -1009,59 +795,46 @@ def _functional_backward( async_op=True, ) - # Perform wgrad GEMM - if not weight_requires_grad: - grad_weight = None - else: - if grad_weight is None: + # Compute grad weight + dw = None + if weight_requires_grad: + + # Synchronize tensor-parallel communication + _wait_async(x_async) + _wait_async(dy_async) + x_async = None + dy_async = None + + # Check grad input tensor + dw = grad_weight + dw_dtype = dtype + if dw is None: if accumulate_into_grad_weight: raise ValueError( - "Attempted to accumulate into grad weight buffer" - "without providing grad weight" + "Attempted to accumulate into grad weight tensor " + "without providing grad weight tensor" ) - grad_weight = torch.empty( - weight_dims, - dtype=dtype, - device=device, - memory_format=torch.contiguous_format, - ) - _wait_async(dy_async) - _wait_async(x_async) - dy_async = None - x_async = None - if with_fp8_compute: - fp8_gemm( - x.transpose_2d(), - x._scale_inv, - 0, - x._fp8_dtype, - dy.transpose_2d(), - dy._scale_inv, - 0, - dy._fp8_dtype, - grad_weight.dtype, - get_workspace(), - 
accumulate=accumulate_into_grad_weight, - out=grad_weight, - ) else: - gemm( - x, - dy, - x.dtype, - get_workspace(), - accumulate=accumulate_into_grad_weight, - layout="NT", - out=grad_weight, - ) + dw_dtype = dw.dtype + + # Perform wgrad GEMM + dw, *_ = general_gemm( + x, + dy, + get_workspace(), + out_dtype=dw_dtype, + accumulate=accumulate_into_grad_weight, + layout="NT", + out=dw, + use_split_accumulator=_2X_ACC_WGRAD, + grad=True, + ) # Clean up and return grads _wait_async(dy_async) _wait_async(x_async) _wait_async(dx_async) - if dx is not None and grad_input is None: - grad_input = reshape(dx, input_dims) - return grad_input, grad_weight + return dx, dw def op_forward( self, @@ -1071,21 +844,33 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: + # Check which grads are required + input_requires_grad = ctx.requires_grad and input_.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight.requires_grad + # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = self.get_fp8_meta("input") - weight_fp8_meta = self.get_fp8_meta("param") - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - grad_output_fp8_meta = self.get_fp8_meta("grad_output") - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + + # Get quantizers + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + if next_op is not None and 
next_op.num_quantizers("forward") > 0: + output_quantizer = next_op.get_quantizer("forward", 0) + grad_output_quantizer = self.get_quantizer("backward", 0) + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) + + # Configure quantizers + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + input_quantizer.set_usage(columnwise=weight_requires_grad) + weight_quantizer.set_usage(columnwise=False) # Get autocast dtype if needed dtype = None @@ -1096,27 +881,26 @@ def op_forward( output, x_local, _ = BasicLinear._functional_forward( input=input_, weight=self.weight, - device=self.device, dtype=dtype, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass ctx.save_for_backward(x_local) - ctx.with_fp8_compute = with_fp8_compute - ctx.weight_fp8_meta = weight_fp8_meta - ctx.grad_output_fp8_meta = grad_output_fp8_meta - ctx.grad_input_fp8_meta = grad_input_fp8_meta + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizer = input_quantizer + ctx.weight_quantizer = weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer ctx.dtype = dtype - ctx.input_dims = input_.size() - ctx.input_requires_grad = input_.requires_grad - ctx.weight_requires_grad = self.weight.requires_grad + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad ctx.has_prev_op = prev_op is not None return output @@ -1149,21 +933,19 @@ def op_backward( 
grad_output=grad_output, input=x_local, weight=self.weight, - input_dims=ctx.input_dims, - weight_dims=self.weight.size(), input_requires_grad=ctx.input_requires_grad, weight_requires_grad=ctx.weight_requires_grad, - device=self.device, dtype=ctx.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, - with_fp8_compute=ctx.with_fp8_compute, - weight_fp8_meta=ctx.weight_fp8_meta, - grad_output_fp8_meta=ctx.grad_output_fp8_meta, - grad_input_fp8_meta=ctx.grad_input_fp8_meta, + with_quantized_compute=ctx.with_quantized_compute, + input_quantizer=ctx.input_quantizer, + weight_quantizer=ctx.weight_quantizer, + grad_output_quantizer=ctx.grad_output_quantizer, + grad_input_quantizer=ctx.grad_input_quantizer, ) # Clear input tensor if possible diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 65717d5fa5..c5897486e3 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -13,13 +13,9 @@ import torch from transformer_engine_torch import layernorm_bwd, layernorm_fwd -from ...cpp_extensions import ( - layernorm_fwd_fp8, - layernorm_fwd_fp8_inf, - layernorm_fwd_inf, -) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from ...constants import TE_DType +from ...tensor import QuantizedTensor from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -213,60 +209,28 @@ def op_forward( # Check if backward pass is needed requires_grad = ctx.requires_grad - # Check if FP8 is enabled - with_fp8_output = ( + # Check if output is quantized + output_quantizer = None + if ( FP8GlobalStateManager.is_fp8_enabled() and next_op is not None - and 
next_op.num_fp8_scales("input") > 0 - ) - output_fp8_meta = None - if with_fp8_output: - output_fp8_meta = next_op.get_fp8_meta("input") + and next_op.num_quantizers("forward") > 0 + ): + output_quantizer = next_op.get_quantizer("forward", 0) # Compute layer norm - y = None - means = None - rstdevs = None sm_margin = self._sm_margins["forward" if requires_grad else "inference"] - if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) - args = ( - x, - w, - b, - self.eps, - output_fp8_meta[fp8_meta_key], - 0, # fp8_meta_index - fp8_dtype, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - data, means, rstdevs = layernorm_fwd_fp8(*args) - else: - data = layernorm_fwd_fp8_inf(*args) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - args = ( - x, - w, - b, - self.eps, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - y, means, rstdevs = layernorm_fwd(*args) - else: - y = layernorm_fwd_inf(*args) + y, means, rstdevs = layernorm_fwd( + x, + w, + b, + self.eps, + None, + output_quantizer, + TE_DType[dtype], + sm_margin, + self.zero_centered_gamma, + ) # Save state for backward pass if requires_grad: diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index e3755decd6..448954fc69 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -9,8 +9,8 @@ import torch -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor from ..op import BasicOperation, OperationContext @@ -38,10 +38,10 @@ def __init__( self._quantize_forward = forward self._quantize_backward = backward 
- def num_fp8_scales(self, mode: str) -> int: - if mode == "input" and self._quantize_forward: + def num_quantizers(self, mode: str) -> int: + if mode == "forward" and self._quantize_forward: return 1 - if mode == "grad_output" and self._quantize_backward: + if mode == "backward" and self._quantize_backward: return 1 return 0 @@ -61,15 +61,7 @@ def op_forward( # Quantize if needed out = input_ if quantize_forward and not isinstance(out, QuantizedTensor): - fp8_meta = self.get_fp8_meta("input") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - out = Float8Tensor.to_float8( - out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) + out = self.get_quantizer("forward", 0)(out) ctx.quantize_backward = quantize_backward return out @@ -81,13 +73,5 @@ def op_backward( ) -> tuple[torch.Tensor, tuple[()]]: grad_input = grad_output if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): - fp8_meta = self.get_fp8_meta("grad_output") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - grad_input = Float8Tensor.to_float8( - grad_input, - fp8_meta=fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) + grad_input = self.get_quantizer("backward", 0)(grad_input) return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 03a02786b4..adfd46641b 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -9,9 +9,9 @@ import torch -from ...tensor import Float8Tensor, QuantizedTensor +from ...distributed import gather_along_first_dim +from ...tensor import QuantizedTensor from ..op import BasicOperation, OperationContext -from .._common import convert_tensor class ReduceScatter(BasicOperation): @@ -45,7 +45,7 @@ def op_forward( # Trivial case if self.process_group_size == 1: - return input_ + 
return input_.detach() # Tensor dimensions input_dims = input_.size() @@ -74,47 +74,9 @@ def op_backward( ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: - - # Trivial case + grad_input: torch.Tensor if self.process_group_size == 1: - return grad_output, () - - # Tensor dimensions - output_dims = grad_output.size() - if not output_dims: - raise RuntimeError( - "Attempted to all-gather a tensor " - f"with shape={list(output_dims)} " - f"over {self.process_group_size} processes" - ) - input_dims = list(output_dims) - input_dims[0] *= self.process_group_size - - # Perform all-gather - dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) - dx = None - if isinstance(dy, Float8Tensor): - dx = Float8Tensor.make_like( - dy, - data=torch.empty( - input_dims, - dtype=torch.uint8, - device=dy.device, - ), - ) - torch.distributed.all_gather_into_tensor( - dx._data, - dy._data, - group=self.process_group, - ) + grad_input = grad_output.detach() else: - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() - dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) - torch.distributed.all_gather_into_tensor( - dx, - dy, - group=self.process_group, - ) - - return dx, () + grad_input, _ = gather_along_first_dim(grad_output, self.process_group) + return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index 53524cdd83..1e9095169c 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -14,7 +14,6 @@ BasicOperation, OperationContext, ) -from .._common import reshape class Reshape(BasicOperation): @@ -42,11 +41,11 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: ctx.input_shape = input_.size() - return reshape(input_, self._shape) + return input_.reshape(*self._shape) def op_backward( self, ctx: OperationContext, grad_output: torch.Tensor, ) -> 
tuple[torch.Tensor, tuple[()]]: - return reshape(grad_output, ctx.input_shape), () + return grad_output.reshape(*ctx.input_shape), () diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 32ef242b90..c1f32af93a 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -13,13 +13,9 @@ import torch from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd -from ...cpp_extensions import ( - rmsnorm_fwd_fp8, - rmsnorm_fwd_fp8_inf, - rmsnorm_fwd_inf, -) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor +from ...constants import TE_DType from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -193,57 +189,27 @@ def op_forward( # Check if backward pass is needed requires_grad = ctx.requires_grad - # Check if FP8 is enabled - with_fp8_output = ( + # Check if output is quantized + output_quantizer = None + if ( FP8GlobalStateManager.is_fp8_enabled() and next_op is not None - and next_op.num_fp8_scales("input") > 0 - ) - output_fp8_meta = None - if with_fp8_output: - output_fp8_meta = next_op.get_fp8_meta("input") + and next_op.num_quantizers("forward") > 0 + ): + output_quantizer = next_op.get_quantizer("forward", 0) # Compute RMSNorm - y = None - rstdevs = None sm_margin = self._sm_margins["forward" if requires_grad else "inference"] - if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) - args = ( - x, - w, - self.eps, - output_fp8_meta[fp8_meta_key], - 0, # fp8_meta_index - fp8_dtype, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - data, rstdevs = rmsnorm_fwd_fp8(*args) - else: - data = rmsnorm_fwd_fp8_inf(*args) - y = Float8Tensor( - data=data, - 
fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - args = ( - x, - w, - self.eps, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - y, rstdevs = rmsnorm_fwd(*args) - else: - y = rmsnorm_fwd_inf(*args) + y, _, rstdevs = rmsnorm_fwd( + x, + w, + self.eps, + None, + output_quantizer, + TE_DType[dtype], + sm_margin, + self.zero_centered_gamma, + ) # Save state for backward pass if requires_grad: diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 1ddd8d116c..e295929e98 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -73,11 +73,8 @@ def fuser_backward( grad_output=grad_output, input=x_local, weight=linear_op.weight, - input_dims=linear_op_ctx.input_dims, - weight_dims=linear_op.weight.size(), input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, - device=linear_op.device, dtype=grad_input.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, @@ -86,10 +83,11 @@ def fuser_backward( tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=linear_op_ctx.with_fp8_compute, - weight_fp8_meta=linear_op_ctx.weight_fp8_meta, - grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, - grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + with_quantized_compute=linear_op_ctx.with_quantized_compute, + input_quantizer=linear_op_ctx.input_quantizer, + weight_quantizer=linear_op_ctx.weight_quantizer, + grad_output_quantizer=linear_op_ctx.grad_output_quantizer, + grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) if accumulate_into_main_grad: grad_weight = None diff --git 
a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index c746f21f2c..6088b3c0db 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -83,22 +83,22 @@ def fuser_forward( raise NotImplementedError("Activations are not yet supported") # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = linear_op.get_fp8_meta("input") - weight_fp8_meta = linear_op.get_fp8_meta("param") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) next_op = basic_op_next_ops[-1] - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + if next_op is not None and next_op.num_quantizers("forward") > 0: + output_quantizer = next_op.get_quantizer("forward", 0) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) # Get autocast dtype if needed dtype = None @@ -110,25 +110,24 @@ def fuser_forward( input=input_, weight=linear_op.weight, bias=bias, - device=linear_op.device, dtype=dtype, 
tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass linear_op_ctx.save_for_backward(x_local) - linear_op_ctx.with_fp8_compute = with_fp8_compute - linear_op_ctx.weight_fp8_meta = weight_fp8_meta - linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta - linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index fa7f07cb95..69b0c3ba5a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -77,19 +77,19 @@ def fuser_forward( raise ValueError("Bias operation forward does not expect keyword arguments") # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = 
linear_op.get_fp8_meta("input") - weight_fp8_meta = linear_op.get_fp8_meta("param") - grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) # Get autocast dtype if needed dtype = None @@ -102,26 +102,25 @@ def fuser_forward( input=input_, weight=linear_op.weight, bias=bias, - device=linear_op.device, out=output, accumulate_into_out=True, tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass linear_op_ctx.save_for_backward(x_local) - linear_op_ctx.with_fp8_compute = with_fp8_compute - linear_op_ctx.weight_fp8_meta = weight_fp8_meta - linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta - linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + 
linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index dab4c8f681..bbb27f86e6 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -4,6 +4,8 @@ """Linear layer backward with Userbuffers communication.""" +# pylint: skip-file ### TODO Debug Userbuffers support + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional @@ -12,11 +14,7 @@ import torch from transformer_engine_torch import CommOverlapAlgo -from ...cpp_extensions import ( - fp8_cast_transpose_bgrad_fused, - fp8_gemm, - gemm, -) +from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -49,6 +47,9 @@ def __init__( reduce_scatter: Optional[ReduceScatter], ) -> None: + ### TODO Debug Userbuffers support + raise NotImplementedError("Userbuffers support has been broken by recent refactors") + # Basic operations that comprise this fused operation op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} ops = [] @@ -706,6 +707,8 @@ def fuse_userbuffers_backward_linear( """ + return ops ### TODO Debug Userbuffers support + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git 
a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 1f3635eb4b..a08c0a6ef9 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -4,6 +4,8 @@ """Linear layer forward with Userbuffers communication.""" +# pylint: skip-file ### TODO Debug Userbuffers support + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional @@ -11,7 +13,7 @@ import torch from transformer_engine_torch import CommOverlapAlgo -from ...cpp_extensions import fp8_gemm, gemm +from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -49,6 +51,9 @@ def __init__( reduce_scatter: Optional[ReduceScatter], ) -> None: + ### TODO Debug Userbuffers support + raise NotImplementedError("Userbuffers support has been broken by recent refactors") + # Basic operations that comprise this fused operation op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None} ops = [linear] @@ -524,6 +529,8 @@ def fuse_userbuffers_forward_linear( """ + return ops ### TODO Debug Userbuffers support + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 30367d2c5e..8346d31a40 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -13,13 +13,14 @@ import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import ( - DelayedScaling, +from transformer_engine.common.recipe import Recipe +from ..fp8 import ( + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, FP8GlobalStateManager, - 
get_default_fp8_recipe, + RecipeState, ) -from ._common import canonicalize_device +from ..tensor import Quantizer @dataclasses.dataclass @@ -174,132 +175,148 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): def __init__(self) -> None: super().__init__() - # FP8 metadata objects + # Objects for quantization + self._quantizers: Optional[dict[str, list[Quantizer]]] = None self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None @property def is_fused_op(self) -> bool: return False - def num_fp8_scales( + def num_quantizers( self, mode: str, # pylint: disable=unused-argument ) -> int: - """Number of FP8 scaling factors + """Number of quantizers + + Matches number of quantized tensors used in operation. Parameters ---------- - mode: {"input", "param", "grad_output"} - Type of FP8 scaling factor + mode: {"forward", "backward"} + Quantizer type """ return 0 - def _make_fp8_metas(self) -> dict[str, Optional[dict[str, Any]]]: - """Construct FP8 metadata""" - - # Shared objects for FP8 metadata - dtype = torch.float32 - device = canonicalize_device(None) - recipe = get_default_fp8_recipe() - - def _make_meta( - num_scales: int, - is_forward: bool, - ) -> Optional[dict[str, Any]]: - """Construct FP8 metadata for one tensor type""" - if num_scales == 0: - return None - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(num_scales, dtype=dtype, device=device) - meta.scale_inv = torch.ones(num_scales, dtype=dtype, device=device) - meta.amax_history = torch.zeros( - (recipe.amax_history_len, num_scales), - dtype=dtype, - device=device, + def _reset_quantization_recipe_state( + self, + *, + recipe: Optional[Recipe] = None, + ) -> None: + """Construct state for quantization recipe""" + + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() + + # Quantization recipe state for forward and backward pass + self._fp8_metas = {"forward": None, "backward": 
None} + self._quantizers = {"forward": [], "backward": []} + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue + + # Construct quantization recipe state + recipe_state = RecipeState.create( + recipe, + mode=mode, + num_quantizers=num_quantizers, ) - return { - key: meta, + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + self._fp8_metas[mode] = { + fp8_meta_key: recipe_state, "recipe": recipe, - "fp8_group": None, + "fp8_group": FP8GlobalStateManager.get_fp8_group(), } - # Construct FP8 metadata for all tensor types - return { - "input": _make_meta(self.num_fp8_scales("input"), True), - "param": _make_meta(self.num_fp8_scales("param"), True), - "grad_output": _make_meta(self.num_fp8_scales("grad_output"), False), - } - - @classmethod - def _maybe_update_fp8_meta( - cls, - fp8_meta: Optional[dict[str, Any]], + # Construct builder class for quantized tensors + self._quantizers[mode] = recipe_state.make_quantizers() + + def _update_quantization_recipe_state( + self, *, - fp8_recipe: Optional[DelayedScaling] = None, + recipe: Optional[Recipe] = None, ) -> None: - if fp8_meta is None: - return + """Make sure quantizer state matches quantization recipe""" - # Update FP8 recipe - if fp8_recipe is None: - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = fp8_recipe + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() - # Update FP8 communication group - fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - - # Adjust amax history length if needed - amax_history_len = fp8_recipe.amax_history_len - for is_forward in (True, False): - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if fp8_meta_key not in fp8_meta: + # Reset quantization state if needed + if self._fp8_metas is None or self._quantizers is None: + 
self._reset_quantization_recipe_state(recipe=recipe) + return + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: continue - meta = fp8_meta[fp8_meta_key] - curr_len = meta.amax_history.size(0) - - # Nothing to be done if amax history is already correct - if curr_len == amax_history_len: + recipe_state = self._fp8_metas[mode][fp8_meta_key] + need_to_reset_recipe_state = ( + recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + if need_to_reset_recipe_state: + self._reset_quantization_recipe_state(recipe=recipe) + return + + # Quantization recipe state for forward and backward pass + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: continue - # Reallocate amax history - with torch.no_grad(): - if curr_len > amax_history_len: - meta.amax_history = meta.amax_history[:amax_history_len].clone() - else: - meta.amax_history = torch.nn.functional.pad( - meta.amax_history, - pad=(0, 0, 0, amax_history_len - curr_len), - ) + # Update FP8 metadata + fp8_meta = self._fp8_metas[mode] + fp8_meta["recipe"] = recipe + fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - # Update global buffers for amax reductions - buffer_info_key = FP8GlobalStateManager.get_buffer_info() - if buffer_info_key in fp8_meta: - fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] - for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." 
- FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ - fp8_meta_key - ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[ - fp8_meta_key - ].amax_history - - def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: - """FP8 metadata + # Get recipe state + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = fp8_meta[fp8_meta_key] + + # Reallocate amax history if needed + if recipe.mxfp8(): + continue + + current_length = recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len + if current_length != target_length: + with torch.no_grad(): + if target_length < current_length: + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + else: + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) + self._quantizers[mode] = recipe_state.make_quantizers() + + def get_quantizer( + self, + mode: str, + index: int, + ) -> Quantizer: + """Get builder class for quantized tensor Parameters ---------- - mode: {"input", "param", "grad_output"} - Type of FP8 scaling factor + mode: {"forward", "backward"} + Quantizer type """ - if self._fp8_metas is None: - self._fp8_metas = self._make_fp8_metas() - return self._fp8_metas[mode] + if self._quantizers is None: + self._reset_quantization_recipe_state() + return self._quantizers[mode][index] @torch.no_grad() def _save_fp8_metas(self) -> Optional[dict[str, Any]]: @@ -321,7 +338,6 @@ def _save_fp8_metas(self) -> Optional[dict[str, Any]]: continue out[mode][fp8_meta_key] = ( fp8_meta[fp8_meta_key].scale.clone(), - fp8_meta[fp8_meta_key].scale_inv.clone(), fp8_meta[fp8_meta_key].amax_history.clone(), ) return out @@ -346,16 +362,15 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: assert ( fp8_meta_key in self._fp8_metas[mode] ), f"Found an unexpected key 
({mode=}, {fp8_meta_key=}) in saved FP8 metadata" - scale, scale_inv, amax_history = tensors + scale, amax_history = tensors self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) - self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) def pre_forward( self, *, fp8_enabled: Optional[bool] = None, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, ) -> None: """Preprocessing before forward pass""" @@ -363,28 +378,15 @@ def pre_forward( if fp8_enabled is None: fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: - - # Construct FP8 metadata if needed - if self._fp8_metas is None: - self._fp8_metas = self._make_fp8_metas() - - # Make sure FP8 metadata matches FP8 autocast context - for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) - - # Register FP8 metadata for amax and scale update + self._update_quantization_recipe_state(recipe=fp8_recipe) if not FP8GlobalStateManager.fp8_graph_capturing(): - if self.num_fp8_scales("input"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("input"), - ) - if self.num_fp8_scales("param"): + if self.num_quantizers("forward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("param"), + self._fp8_metas["forward"], ) - if self.num_fp8_scales("grad_output"): + if self.num_quantizers("backward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("grad_output"), + self._fp8_metas["backward"], ) @abc.abstractmethod @@ -527,13 +529,6 @@ def get_extra_state(self) -> torch.Tensor: # See: https://github.com/NVIDIA/TransformerEngine/pull/351 # See: https://github.com/NVIDIA/TransformerEngine/pull/363 - # Return immediately if op has no FP8 state - has_fp8_state = any( - self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") - ) - if not has_fp8_state: - return 
torch.Tensor() - def to_cpu(src: torch.Tensor) -> torch.Tensor: """Helper function to make CPU copy of tensor @@ -547,25 +542,20 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Store FP8 state state = {} - for mode in ("input", "param", "grad_output"): + for mode in ("forward", "backward"): # Get state for a given FP8 tensor - if self.num_fp8_scales(mode) == 0: - state[mode] = None + if self.num_quantizers(mode) == 0: continue fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue state[mode] = {} # Store tensors if "scaling_fwd" in fp8_meta: state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv) state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) if "scaling_bwd" in fp8_meta: state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv) state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items @@ -591,7 +581,7 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: # Deserialize state from byte tensor state = pickle.loads(state.detach().numpy(force=True).tobytes()) - if state is None: + if state is None or len(state) == 0: return def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: @@ -606,12 +596,12 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load FP8 state - for mode in ("input", "param", "grad_output"): + for mode in ("forward", "backward"): # Get state for a given FP8 tensor if mode not in state: continue - if self.num_fp8_scales(mode) == 0: + if self.num_quantizers(mode) == 0: continue fp8_meta = self.get_fp8_meta(mode) if fp8_meta is None: @@ -631,12 +621,10 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: if "scaling_fwd" in fp8_meta: fp8_meta_fwd = fp8_meta["scaling_fwd"] 
copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv) copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) if "scaling_bwd" in fp8_meta: fp8_meta_bwd = fp8_meta["scaling_bwd"] copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv) copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) # Finish CPU-GPU memory transfers diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b86c973304..d972fd96ab 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -8,24 +8,20 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from .multi_tensor_apply import multi_tensor_applier -from ..float8_tensor import Float8Tensor def get_fp8_meta(fp8_tensor): """FP8 metadata getter.""" - if fp8_tensor._fp8_meta is None: - raise RuntimeError("FP8 meta data is not initialized.") + assert isinstance(fp8_tensor, Float8Tensor), "Fused optimizer supports only Float8Tensor class" + if fp8_tensor._quantizer is None: + raise RuntimeError("FP8 quantizer data is not initialized.") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=fp8_tensor._fp8_meta_forward, - ) + quantizer = fp8_tensor._quantizer - fp8_meta_index = fp8_tensor._fp8_meta_index - scale = fp8_tensor._fp8_meta[fp8_meta_key].scale[fp8_meta_index] - amax = fp8_tensor._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + scale = quantizer.scale + amax = quantizer.amax scale_inv = fp8_tensor._scale_inv return scale, amax, scale_inv @@ -237,6 +233,10 @@ def _apply_scale(self, 
state_name, unscaled_state, scaled_state, scale): dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) + assert len(scaled_state._quantizer.scale) == 1, ( + "Only scaling with one scaling factor per tensor is supported by the" + " FusedAdam." + ) else: assert scaled_state.dtype == dtype @@ -251,7 +251,7 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device) torch.div(absmax, max_range, out=scale) if isinstance(scaled_state, Float8Tensor): - scaled_state._scale_inv.copy_(scale) + scaled_state._quantizer.scale.copy_(1 / scale) scaled_state.copy_(unscaled_state) else: rscale = torch.where(scale > 0, scale.reciprocal(), 0.0) @@ -269,7 +269,6 @@ def get_unscaled_state(self, param, state_name): state = self.state[param] dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: - assert isinstance(state[state_name], Float8Tensor) unscaled = state[state_name].float() elif dtype == torch.float16: assert state[state_name].dtype == torch.float16 @@ -343,12 +342,15 @@ def _initialize_state( data.zero_() if dtype == torch.uint8: - self.state[param][state_name] = Float8Tensor( - data=data, - dtype=torch.float32, - fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device), + quantizer = Float8Quantizer( + scale=torch.ones([1], dtype=torch.float32, device=param.device), + amax=torch.zeros([1], dtype=torch.float32, device=param.device), + fp8_dtype=tex.DType.kFloat8E4M3, ) + self.state[param][state_name] = quantizer.make_empty(param.shape) + self.state[param][state_name].quantize_(data.float()) else: + self.state[param][state_name] = data # Create scale if necessary. 
@@ -421,6 +423,8 @@ def load_state_dict(self, state_dict): param = id_map[k] self.state[param] = {} for name in v: + if v[name] is None: + continue if ( self.store_param_remainders and name == "master_param" diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 264b620be8..2e6167a6e0 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -48,8 +48,12 @@ def forward( # Data type check fp8 = isinstance(inp, Float8Tensor) if fp8: + assert ( + inp._quantizer.scale.ndim == 0 + ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute." dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: dtype = TE_DType[inp.dtype] @@ -78,7 +82,11 @@ def forward( if fp8: permuted_act = Float8Tensor( - data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + data=permuted_act, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=permuted_act.shape, + dtype=fake_dtype, ) ctx.row_id_map = row_id_map @@ -107,6 +115,7 @@ def backward( ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." 
dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data else: dtype = TE_DType[permuted_act_grad.dtype] @@ -118,7 +127,11 @@ def backward( ) if ctx.fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK + data=act_grad, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv * ctx.topK, + shape=act_grad.shape, + dtype=fake_dtype, ) return act_grad, None, None, None @@ -167,6 +180,7 @@ def forward( if fp8: dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: dtype = TE_DType[inp.dtype] @@ -181,7 +195,11 @@ def forward( if fp8: unpermuted_output = Float8Tensor( - data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + data=unpermuted_output, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=unpermuted_output.shape, + dtype=fake_dtype, ) ctx.save_for_backward(inp, row_id_map, probs) @@ -207,6 +225,7 @@ def backward( ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." 
dtype = unpermuted_act_grad._fp8_dtype fp8_scale_inv = unpermuted_act_grad._scale_inv + fake_dtype = unpermuted_act_grad.dtype unpermuted_act_grad = unpermuted_act_grad._data else: dtype = TE_DType[unpermuted_act_grad.dtype] @@ -220,7 +239,13 @@ def backward( unpermuted_act_grad, inp, dtype, row_id_map, probs ) if ctx.fp8: - act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv) + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) if not ctx.needs_input_grad[2]: prob_grad = None diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index d3b3f03e10..20503fea2f 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -12,10 +12,9 @@ from pathlib import Path import setuptools -from torch.utils.cpp_extension import BuildExtension try: - import torch # pylint: disable=unused-import + from torch.utils.cpp_extension import BuildExtension except ImportError as e: raise RuntimeError("This package needs Torch to build.") from e @@ -57,7 +56,7 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["torch"], - tests_require=["numpy", "onnxruntime", "torchvision"], + tests_require=["numpy", "torchvision"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 3950c071b6..25362e1d58 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -7,11 +7,7 @@ from typing import Callable, Tuple, Union, Optional import torch from torch import nn -import torch._C._onnx as _C_onnx -from torch.onnx import _type_utils import transformer_engine_torch as tex -from transformer_engine.pytorch.export import is_in_onnx_export_mode -from transformer_engine.pytorch.te_onnx_extensions import 
compute_in_fp32 THREADS_PER_WARP = 32 @@ -32,35 +28,6 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: return _default_causal_mask[matrix_identifiers] -def _get_onnx_export_causal_mask( - seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor -) -> torch.Tensor: - """Return the causal upper triangular mask for softmax input, for ONNX export. - - ONNX does not support dynamic control-flow and requires non-square masks when - using a KV-cache (seq_k's length len(context)+len(generative) while seq_q's length is 1). - - Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the correct - shape for GPT context and generation phases. - In the context phase the derived mask is a square triu of shape (seq_k, seq_k), and in - the generation phase the mask is rectangular with shape (1, seq_k). - """ - assert len(onnx_causal_mask.size()) == 2 - assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1) - assert onnx_causal_mask.size(0) >= (seq_k - seq_q) >= 0 - derived_mask = onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k] - return derived_mask - - -def fp32_compute(onnx_symbolic_fn): - """A decorator that wraps an ONNX symoblic function with FP32 compute operators.""" - - def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs): - return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs) - - return wrapper - - class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ Fused operation which performs following three operations in sequence @@ -88,34 +55,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] return input_grads, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledUpperTriangMaskedSoftmax symbolic method""" - - def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 - ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) - k = 
g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - mask = g.op("Trilu", ones, k, upper_i=1) - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - return mask - - # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward - mask = triangular_mask() - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): """ @@ -143,40 +82,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] return input_grads, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledAlignedCausalMaskedSoftmax symbolic method""" - - def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 - ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) - k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - - # rectangular causal mask aligned to the bottom right corner of Attention matrix - rows = inputs.size(dim=-2) - cols = inputs.size(dim=-1) - diag_shift = cols - rows + 1 - - mask = g.op("Trilu", ones, k, upper_i=diag_shift) - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - return mask - - # Captures the logic of function scaled_aligned_masked_softmax_warp_forward - mask = triangular_mask() - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - 
softmax_mask = g.op("Mul", mask, neg_tenK) - - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledMaskedSoftmax(torch.autograd.Function): """ @@ -203,30 +108,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None - @staticmethod - @fp32_compute - def symbolic( - g: torch.Graph, inputs: torch._C.Value, mask: torch._C.Value, scale: float - ) -> torch._C.Value: - """ScaledMaskedSoftmax symbolic method""" - # Captures the logic of function scaled_masked_softmax_warp_forward. - # output = softmax(mask(input*scale) - # Computed as: - # masked_scaled = (1 - mask)*(input*scale) - # softmax_mask = mask * -10000 - # output = softmax(masked_scaled + softmax_mask) - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - # Note: type is hard coded because softmax uses FP16 or BF16 - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledSoftmax(torch.autograd.Function): """ @@ -252,15 +133,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> 
torch._C.Value: - """ScaledSoftmax symbolic method""" - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - out = g.op("Softmax", scaled) - return out - class FusedScaleMaskSoftmax(nn.Module): """ @@ -281,18 +153,6 @@ def __init__( self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 - # Users exporting to ONNX can optimize the attention mask for GPT text generation. - self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1")) - if self.kvcache_max_seq > 0: - self.register_buffer( - "onnx_causal_mask", - torch.triu( - torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"), - diagonal=1, - ).bool(), - persistent=False, - ) - def forward( self, inp: torch.Tensor, @@ -310,7 +170,7 @@ def forward( assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode(): + if self.is_kernel_available(mask, *inp.size()): return self.forward_fused_softmax(inp, mask, scale) return self.forward_torch_softmax(inp, mask, scale) @@ -363,8 +223,9 @@ def forward_fused_softmax( """ scale = 1.0 if scale is None else scale - if self.attn_mask_type in ["causal", "causal_bottom_right"]: - return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) + # Disable for now until unalignment bug is fixed. 
+ # if self.attn_mask_type in ["causal", "causal_bottom_right"]: + # return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": @@ -383,13 +244,7 @@ def forward_torch_softmax( if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) - if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: - assert self.kvcache_max_seq >= seq_len_k - causal_mask = _get_onnx_export_causal_mask( - seq_len_q, seq_len_k, self.onnx_causal_mask - ) - else: - causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) if mask is None: mask = causal_mask else: diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py deleted file mode 100755 index 54eb37ecab..0000000000 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -ONNX symbolic functions for Transformer Engine - -Warnings of the type pasted below are a known Pytorch issue -(https://github.com/pytorch/pytorch/issues/81693): - -tests/test_onnx_export.py::test_export_cast_ops[112] - /opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649: - UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing, - so it may result in wrong shape inference for the exported graph. - Please consider adding it in symbolic function. (Triggered internally at - /opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.) - _C._jit_pass_onnx_graph_shape_type_inference( - - -Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access -specific entries using the index passes as `fp8_tensor`. 
If you fail to do this you will get -the following error when accessing a sepcific scale element (e.g. `scale_inv[fp8_tensor]`): - TypeError: 'torch._C.Value' object is not subscriptable -""" - -import torch -from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils -import torch._C._onnx as _C_onnx - -# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx._internal import jit_utils - -import transformer_engine_torch as tex - - -# This file registers custom op symbolic ONNX functions and does not export any symbols. -__all__ = [] - - -# Custom ops spec version -VER = 1 -UNSPECIFIED_TYPE = -1 - - -def make_op_name(op_name: str) -> str: - """custom op name""" - return "trt::" + op_name - - -def get_TensorProtoDataType(t): - """Return the _C_onnx.TensorProtoDataType of the input tensor""" - try: - return { - "Float": _C_onnx.TensorProtoDataType.FLOAT, - "Half": _C_onnx.TensorProtoDataType.FLOAT16, - "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, - }[t.type().scalarType()] - except KeyError as e: - raise TypeError(f"Onnx export for dtype {t.type().scalarType()} not supported.") from e - - -def is_dtype_fp32(t): - """Check fp32 dtype""" - return t.type().scalarType() == "Float" - - -def is_dtype_fp16(t): - """Check fp16 dtype""" - return t.type().scalarType() == "Half" - - -def is_dtype_bf16(t): - """Check bf16 dtype""" - return t.type().scalarType() == "BFloat16" - - -def quantize(g, inputs, scale, fp8_tensor): - """Helper Function for Quantization""" - output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - - # Q inputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the input if needed. 
- if not is_dtype_fp32(inputs): - inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor])) - q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) - ) - return q_op - - -def dequantize(g, inputs, scale_inv, fp8_tensor, otype): - """Helper Function for Dequantization""" - output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - - scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) - out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.float32).with_sizes(output_shape) - ) - - # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the output if needed. - if otype == int(tex.DType.kFloat16): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - elif otype == int(tex.DType.kBFloat16): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return out - - -def compute_in_fp32(g, inp, subgraph, *args, **kwargs): - """Wrap subgraph with casts to/from FP32 so that its precision is FP32. - - If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`; - then cast subgraphs's output back to `inp` data type. 
- """ - inp_dtype = get_TensorProtoDataType(inp) - is_fp32 = inp_dtype == _type_utils.JitScalarType.FLOAT - if not is_fp32: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - sg_out = subgraph(g, inp, *args, **kwargs) - if not is_fp32: - sg_out = g.op("Cast", sg_out, to_i=inp_dtype) - return sg_out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for cast_to_fp8""" - # pylint: disable=unused-argument - return quantize(g, inputs, scale, fp8_tensor) - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i") -def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for cast_to_fp8_noalloc""" - # pylint: disable=unused-argument - return quantize(g, inputs, scale, fp8_tensor) - - -@symbolic_helper.parse_args("v", "fs", "i", "i", "i") -def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): - """ONNX graph for cast_from_fp8""" - # pylint: disable=unused-argument - return dequantize(g, inputs, scale_inv, fp8_tensor, otype) - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_gelu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_gelu""" - # pylint: disable=unused-argument - # TE computes GELU using float32 precision so wrap the GELU subgraph with - # conversion to/from float32. 
- dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = torch.onnx.symbolic_opset9.gelu(g, inp, "tanh") - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_relu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_relu""" - # pylint: disable=unused-argument - out = torch.onnx.symbolic_opset9.relu(g, inp) - if scale: - out = quantize(g, out, scale, fp8_tensor) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for swiglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - out = g.op("Mul", g.op("Sigmoid", first), second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_swiglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_swiglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_swiglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_reglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph 
for reglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - out = g.op("Mul", g.op("Relu", first), second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_reglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_reglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_reglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_geglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for geglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - first = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") - out = g.op("Mul", first, second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_geglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_geglu""" - # pylint: disable=unused-argument - dtype = 
get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_geglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args( - "v", - "fs", - "i", - "i", - "i", - "v", - "fs", - "i", - "i", - "i", - "v", - "fs", - "i", - "v", - "v", - "i", - "v", - "i", - "v", - "i", - "i", - "i", -) -def onnx_te_gemm( - g, - weight, - weight_scale_inverse, - weight_fp8_tensor, - weight_type, - trans_weight, - inputs, - input_scale_inverse, - input_fp8_tensor, - input_type, - trans_input, - out, - out_scale, - out_type, - out_amax, - bias, - bias_type, - pre_gelu_out, - grad, - workspace, - workspaceSize, - accumulate, - use_split_accumulator, -): - """ONNX graph for te_gemm""" - # pylint: disable=unused-argument - is_fp16 = is_dtype_fp16(inputs) - is_bf16 = is_dtype_bf16(inputs) - if input_type == int(tex.DType.kFloat8E4M3): - inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type) - - if weight_type == int(tex.DType.kFloat8E4M3): - weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, out_type) - - empty_tensor_size = [0] - bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size - pre_gelu_out_empty = ( - torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) == empty_tensor_size - ) - - if not bias_empty: - output = g.op("Gemm", inputs, weight, bias, transA_i=trans_input, transB_i=trans_weight) - else: - output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight) - if not bias_empty: - if not pre_gelu_out_empty: - # TE computes GELU using float32 precision so wrap the GELU subgraph with - # conversion to/from float32. 
- output = compute_in_fp32(g, output, torch.onnx.symbolic_opset9.gelu, "tanh") - else: - if is_fp16: - output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - elif is_bf16: - output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return output - - -def _ones_like(g, inp, dtype): - """Returns a tensor filled with the scalar value 1, with the same size as input and - with dtype data-type""" - shape = g.op("Shape", inp) - # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR - # create a ConstantOfShape with type FP32 and then add a Cast to BF16. - is_bf16 = dtype == torch.bfloat16 - one = g.op( - "ConstantOfShape", - shape, - value_t=torch.tensor([1], dtype=torch.float32 if is_bf16 else dtype), - ) - if is_bf16: - one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return one - - -@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") -def onnx_layernorm_fwd_fp8( - g, - inputs, - weight, - bias, - eps, - scale, - amax, - scale_inv, - fp8_tensor, - otype, - sm_margin, - zero_centered_gamma, -): - """ONNX graph for layernorm_fwd_fp8""" - # pylint: disable=unused-argument - inp_dtype = get_TensorProtoDataType(inputs) - - if inp_dtype != get_TensorProtoDataType(weight): - weight = g.op("Cast", weight, to_i=inp_dtype) - if inp_dtype != get_TensorProtoDataType(bias): - bias = g.op("Cast", bias, to_i=inp_dtype) - - ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale, fp8_tensor) - return fp8_ln - - -@symbolic_helper.parse_args("v", "v", "v", "f", "i", "b") -def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma): - """ONNX graph for layernorm_fwd""" - # pylint: disable=unused-argument - - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) - assert ndim is not 
None - normalized_shape = list(range(0, ndim)) - # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 - normalized_shape = normalized_shape[1:] - - if zero_centered_gamma: - inputs_dtype = inputs.type().dtype() - one = _ones_like(g, weight, inputs_dtype) - weight = g.op("Add", weight, one) - - axis = -len(normalized_shape) - ln = g.op( - "LayerNormalization", - inputs, - weight, - bias, - epsilon_f=eps, - axis_i=axis, - # This sets the LN compute precision - use FP32 always as does TE. - stash_type_i=_C_onnx.TensorProtoDataType.FLOAT, - ) - return ln - - -@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") -def onnx_rmsnorm_fwd_fp8( - g, - inp, - weight, - eps, - scale, - amax, - scale_inv, - fp8_tensor, - otype, - sm_margin, - zero_centered_gamma, -): - """ONNX graph for rmsnorm_fwd_fp8""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma) - out = quantize(g, out, scale, fp8_tensor) - return out - - -@symbolic_helper.parse_args("v", "v", "f", "i", "b") -def onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma): - """ONNX graph for rmsnorm_fwd""" - # pylint: disable=unused-argument - - # Check dimensions - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inp) - if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inp) - assert ndim is not None - normalized_shape = list(range(0, ndim)) - # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 - normalized_shape = normalized_shape[1:] - axis = -len(normalized_shape) - - # Cast input tensors to FP32 if needed - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - if 
get_TensorProtoDataType(weight) != _type_utils.JitScalarType.FLOAT: - weight = g.op("Cast", weight, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - # Adjust zero-centered weights - if zero_centered_gamma: - one = _ones_like(g, weight, torch.float32) - weight = g.op("Add", weight, one) - - # Perform compute in FP32 - sum_square = g.op("ReduceSumSquare", inp, axes_i=[axis]) - shape = g.op("Shape", inp, start_i=-1) - shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) - mean_squared = g.op("Div", sum_square, shape_f) - eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) - rms_squared = g.op("Add", mean_squared, eps_tensor) - rms_eps = g.op("Sqrt", rms_squared) - normalized_input = g.op("Div", inp, rms_eps) - out = g.op("Mul", weight, normalized_input) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER) -register_custom_op_symbolic("tex_ts::cast_to_fp8_noalloc_ts", onnx_cast_to_fp8_noalloc, VER) -register_custom_op_symbolic("tex_ts::cast_from_fp8_ts", onnx_cast_from_fp8, VER) -register_custom_op_symbolic("tex_ts::gelu_ts", onnx_fp8_gelu, VER) -register_custom_op_symbolic("tex_ts::relu_ts", onnx_fp8_relu, VER) -register_custom_op_symbolic("tex_ts::reglu_ts", onnx_fp8_reglu, VER) -register_custom_op_symbolic("tex_ts::geglu_ts", onnx_fp8_geglu, VER) -register_custom_op_symbolic("tex_ts::swiglu_ts", onnx_fp8_swiglu, VER) -register_custom_op_symbolic("tex_ts::te_gemm_ts", onnx_te_gemm, VER) -register_custom_op_symbolic("tex_ts::layernorm_fwd_fp8_inf_ts", onnx_layernorm_fwd_fp8, VER) -register_custom_op_symbolic("tex_ts::layernorm_fwd_inf_ts", onnx_layernorm_fwd, VER) -register_custom_op_symbolic("tex_ts::rmsnorm_fwd_fp8_inf_ts", onnx_rmsnorm_fwd_fp8, VER) -register_custom_op_symbolic("tex_ts::rmsnorm_fwd_inf_ts", onnx_rmsnorm_fwd, VER) diff --git a/transformer_engine/pytorch/tensor/__init__.py 
b/transformer_engine/pytorch/tensor/__init__.py index aceaaf5d10..610ec2a777 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,10 +6,12 @@ import torch -from .float8_tensor import Float8Tensor -from .quantized_tensor import QuantizedTensor +from .quantized_tensor import QuantizedTensor, Quantizer -__all__ = ["Float8Tensor", "QuantizedTensor"] +__all__ = [ + "QuantizedTensor", + "Quantizer", +] def _make_module_cast_func(dtype): @@ -22,14 +24,8 @@ def _make_module_cast_func(dtype): def tensor_cast_func(tensor: torch.Tensor) -> torch.Tensor: """Cast tensor dtype""" - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data, - fp8_attrs=tensor._fp8_attrs, - dtype=dtype, - requires_grad=tensor.requires_grad, - ) + if isinstance(tensor, QuantizedTensor): + return tensor.__class__.make_like(tensor, dtype=dtype) if tensor.is_floating_point(): return getattr(tensor, cast_func_name)() return tensor diff --git a/tests/paddle/test_sanity_import.py b/transformer_engine/pytorch/tensor/_internal/__init__.py similarity index 69% rename from tests/paddle/test_sanity_import.py rename to transformer_engine/pytorch/tensor/_internal/__init__.py index 0390f2f6a0..e13014bf75 100644 --- a/tests/paddle/test_sanity_import.py +++ b/transformer_engine/pytorch/tensor/_internal/__init__.py @@ -1,7 +1,4 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. - -import transformer_engine.paddle - -print("OK") +"""Internal data structures for quantized tensors.""" diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py new file mode 100644 index 0000000000..6b816db3b5 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8Tensor""" + +from __future__ import annotations +from typing import Any, Dict, Optional, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: Float8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._data is not None: + # Cast from FP8 + return tex.dequantize(tensor, dtype) + + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class Float8TensorBase: + """Mixin class that holds data attributes of Float8Tensor. + + Float8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. 
+ + """ + + _data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _scale_inv: torch.Tensor + + # FP8 transpose cache + _transpose: Optional[torch.Tensor] + _transpose_invalid: bool + + def __new__( + cls, + *args, + data: Optional[torch.Tensor], + fp8_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + data_transpose: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + if cls is Float8TensorBase: + instance = object.__new__(cls) + else: + instance = super().__new__(cls, *args, **kwargs) + instance._data = data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._scale_inv = fp8_scale_inv + instance._transpose = data_transpose + instance._transpose_invalid = instance._transpose is None + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "data": self._data, + "fp8_scale_inv": self._scale_inv, + "fp8_dtype": self._fp8_dtype, + "data_transpose": self._transpose, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. 
+ + """ + tensors = [self._data, self._transpose] + # self._data = None + # self._transpose = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list""" + self._data = tensors[0] + self._transpose = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._data, self._transpose + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromFloat8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + return self._data.size(*args, **kwargs) + + def __repr__(self): + return ( + "Float8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py new file mode 100644 index 0000000000..d78bd55d9a --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Mixin class holding data specific for MXFP8Tensor""" + +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromMXFP8Func(torch.autograd.Function): + """Cast from MXFP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: MXFP8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._rowwise_data is not None: + return tex.dequantize(tensor, dtype) + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class MXFP8TensorBase: + """Mixin class that holds data attributes of MXFP8Tensor. + + MXFP8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. 
+ + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. 
+ + """ + tensors = [self._rowwise_data, self._columnwise_data] + # self._rowwise_data = None + # self._columnwise_data = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromMXFP8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + return self._rowwise_data.size(*args, **kwargs) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "MXFP8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index d356df58dc..da788182a0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,25 +4,18 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple, Iterable import warnings import torch import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..constants import TE_DType as torch_to_transformer_engine_dtype -from ..cpp_extensions import ( - cast_from_fp8, - cast_to_fp8, - fp8_cast_transpose_fused, -) -from ..fp8 import FP8GlobalStateManager -from ..utils import devices_match -from .quantized_tensor import QuantizedTensor +from ..utils import devices_match, non_tn_fp8_gemm_supported +from 
._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc aten = torch.ops.aten -updated_fp8_params = {} _ops_to_preserve_subclass_in_fsdp2 = { torch.ops.aten.empty_like.default, @@ -38,265 +31,142 @@ } -def _make_fp8_attr_property_funcs(name: str) -> Any: - """Make accessors for an FP8 attribute +class Float8Quantizer(Quantizer): + """Builder class for FP8 tensors with per-tensor delayed scaling - We store FP8 attributes in a dictionary so we can share them - between tensors with the same data, e.g. detached tensors. For - convenience, we also expose them as property attributes. This - function creates the accessors for property attributes. - - Parameters - ---------- - name: str - Key in dictionary of FP8 attributes + High-precision tensors (e.g. in FP32 or BF16) are quantized by + multiplying with a scaling factor and casting to FP8. The max-abs + value ("amax") in the tensor is also computed, which can be used + for updating the scaling factor (handled externally by + DelayedScalingRecipeState and FP8GlobalStateManager). 
""" - def get_func(self) -> Any: - return self._fp8_attrs[name] + """Scaling factor to multiply when quantizing to FP8""" + scale: torch.Tensor + """Max-abs value from last FP8 cast""" + amax: torch.Tensor + """FP8 datatype""" + dtype: TE_DType - def set_func(self, value: Any) -> None: - self._fp8_attrs[name] = value + def __init__( + self, + scale: torch.Tensor, + amax: torch.Tensor, + fp8_dtype: TE_DType, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.scale = scale + self.amax = amax + self.dtype = fp8_dtype - def del_func(self) -> None: - del self._fp8_attrs[name] + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + if not isinstance(dst, Float8Tensor): + raise ValueError("Float8Quantizer can only update Float8Tensor") - return {"fget": get_func, "fset": set_func, "fdel": del_func} + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) -class _FromFloat8Func(torch.autograd.Function): - """Cast from FP8 to other dtype""" + # Update FP8 dtype + dst._fp8_dtype = self.dtype - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: Float8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - return tensor.dequantize(dtype=dtype) + return dst - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -class _ToFloat8Func(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - 
@staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - data: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - data_transpose: Optional[torch.Tensor] = None, + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, ) -> Float8Tensor: - # pylint: disable=missing-function-docstring - # Tensor attributes - dtype = tensor.dtype - if dtype not in (torch.float32, torch.bfloat16, torch.float16): - dtype = torch.float32 - device = tensor.device - if device.type != "cuda": + # Canonicalize tensor attributes + if device is None: device = torch.device("cuda") - # FP8 data buffer - if data is None: - data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) - - # Check scale - if scale is None and fp8_meta is None: - scale = torch.full([1], 1, dtype=torch.float32, device=device) - if scale is not None: - scale = scale.to(device=device, dtype=torch.float32) - - # Check scale-inverse - if scale_inv is None: - scale_inv = torch.empty([1], dtype=torch.float32, device=device) - elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: - scale_inv = scale_inv.to(device=device, dtype=torch.float32) + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) - # Transpose cache - if data_transpose is None and with_transpose_cache: + # Allocate FP8 data transpose if needed + data_transpose = None + if self.columnwise_usage: + inner_dim = data.size(-1) data_transpose = torch.empty( - (data.size(-1), data.numel() // data.size(-1)), + inner_dim, + data.numel() // inner_dim, 
dtype=torch.uint8, - device=tensor.device, + device=device, ) # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, + return Float8Tensor( + shape=shape, dtype=dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, data_transpose=data_transpose, + quantizer=self, ) - # Cast to FP8 tensor - out.quantize_(tensor, scale=scale, amax=amax) - - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None, None, None - + def calibrate(self, tensor: torch.Tensor) -> None: + amin, amax = tensor.aminmax() + self.amax.copy_(torch.max(-amin, amax)) -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. 
- - """ - - @staticmethod - def forward( - ctx, - tensor: Float8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = { - "data": tensor._data, - "fp8_meta": tensor._fp8_meta, - "fp8_meta_forward": tensor._fp8_meta_forward, - "fp8_meta_index": tensor._fp8_meta_index, - "fp8_dtype": tensor._fp8_dtype, - "fp8_scale_inv": tensor._scale_inv, - "dtype": tensor.dtype, - } - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return Float8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - # pylint: disable=missing-function-docstring - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the Float8Tensor using the provided shape. 
- - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), + def create_tensor_from_data( + self, + data: torch.Tensor, + fake_dtype=torch.float32, + requires_grad: bool = False, + internal: bool = False, + ): + """Create Float8Tensor from raw uint8 data""" + assert data.dtype in [ + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + if internal: + return Float8TensorBase( + data=data, + fp8_scale_inv=1 / self.scale, + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, ) - return dgrad, None - return grad.reshape(ctx.shape), None + return Float8Tensor( + shape=data.shape, + dtype=fake_dtype, + data=data, + fp8_scale_inv=1 / self.scale, + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) -class Float8Tensor(QuantizedTensor): +class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -306,256 +176,69 @@ class Float8Tensor(QuantizedTensor): Parameters ---------- + shape: int or iterable of int + Tensor dimensions. + dtype: torch.dtype + Nominal tensor datatype. 
+ requires_grad: bool, optional = False + Whether to compute gradients for this tensor. data: torch.Tensor - Raw FP8 data in a uint8 tensor - fp8_attrs: dict, optional - FP8 metadata, primarily managed by Float8Tensor. If - provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional - FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. - fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 - FP8 format. + Raw FP8 data in a uint8 tensor fp8_scale_inv: torch.Tensor - Reciprocal of the scaling factor applied when - casting to FP8, i.e. the scaling factor that must - be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. + Reciprocal of the scaling factor applied when casting to FP8, + i.e. the scaling factor that must be applied when casting from + FP8 to higher precision. + fp8_dtype: transformer_engine_torch.DType + FP8 format. 
+ data_transpose: torch.Tensor, optional + FP8 transpose data in a uint8 tensor + quantizer: Float8Quantizer, optional + Builder class for FP8 tensors """ - _data: torch.Tensor - _fp8_attrs: Dict[str, Any] - _fp8_meta: Optional[Dict[str, Any]] - _fp8_meta_forward: bool - _fp8_meta_index: Optional[int] - _fp8_dtype: TE_DType - _scale_inv: torch.Tensor - - # FP8 transpose cache - _transpose: Optional[torch.Tensor] - _transpose_invalid: bool - - def __new__( - cls, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - requires_grad: bool = False, - data_transpose: Optional[torch.Tensor] = None, - ): - - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=requires_grad, - device=data.device, - ) - self._data = data - - # Initialize dict of class attributes - # Note: We store FP8 attributes in a dictionary so we can - # share them between tensors with the same data, e.g. detached - # tensors. 
- if fp8_attrs is None: - self._fp8_attrs = {} - else: - self._fp8_attrs = fp8_attrs - return self - - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta = fp8_meta - self._fp8_meta_forward = fp8_meta_forward - self._fp8_meta_index = fp8_meta_index - - # FP8 dtype - assert fp8_dtype in ( - TE_DType.kFloat8E4M3, - TE_DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." - self._fp8_dtype = fp8_dtype - - # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if ( - not devices_match(fp8_scale_inv.device, self._data.device) - or fp8_scale_inv.dtype != torch.float32 - ): - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale_inv = fp8_scale_inv - - # FP8 transpose cache - self._transpose = data_transpose - self._transpose_invalid = self._transpose is None - - return self - - def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument - """ - A hook function used in torch fsdp2, called before all-gather - return (all-gather input), (metadata) - Ref: https://github.com/pytorch/pytorch/pull/122908 - - """ - - return (self._data,), (self,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: 
torch.dtype, # pylint: disable=unused-argument - *, - out: Optional[torch.Tensor] = None, - ): - """ - A hook function used in torch fsdp2, called after all-gather - return (Float8Tensor class instance of all-gathered input), (Things to free after forward) - Ref: https://github.com/pytorch/pytorch/pull/122908 - - """ - (data,) = all_gather_outputs - (sample,) = metadata - if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - return None - return Float8Tensor.make_like(sample, data=data), all_gather_outputs - - @classmethod - def make_like( - cls, - tensor: Float8Tensor, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Float8Tensor: - """Use attributes of a Float8Tensor to create another Float8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = { - "fp8_meta": tensor._fp8_meta, - "fp8_meta_forward": tensor._fp8_meta_forward, - "fp8_meta_index": tensor._fp8_meta_index, - "fp8_dtype": tensor._fp8_dtype, - "fp8_scale_inv": tensor._scale_inv, - "dtype": tensor.dtype, - } - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) - - def __repr__(self): + def __repr__(self, *, tensor_contents=None): return ( "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " - f"data={self.from_float8(dtype=self.dtype)}" + f"data={self.dequantize(dtype=self.dtype)}" ")" ) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. 
+ """ # Convert PyTorch dtype to TE dtype if dtype is None: dtype = self.dtype - dtype = torch_to_transformer_engine_dtype[dtype] - # Make sure FP8 data is in expected format - data = self._data - if data.device.type != "cuda": - data = data.cuda() - if not data.is_contiguous(): - data = data.contiguous() - if data.dim() != 2: - data = data.view(1, -1) - - # Cast from FP8 - out = cast_from_fp8( - data.view(1, -1), - None, # fp8_meta_tensor - None, # fp8_tensor - self._fp8_dtype, - dtype, - scale_inv=self._scale_inv, - ) + if torch.is_grad_enabled(): + return _FromFloat8Func.apply(self, dtype) + return _FromFloat8Func.forward(None, self, dtype) - # Make sure output is in expected format - if out.size() != self.size(): - out = out.view(self.size()) - return out + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from Float8Tensor + Quantizer can be used for in-place operations. - By default the resulting tensor's dtype is the - Float8Tensor's nominal dtype. """ - return _FromFloat8Func.apply(self, dtype) + if self._quantizer is not None: + return self._quantizer + return Float8Quantizer( + scale=torch.reciprocal(self._scale_inv), + amax=torch.empty(1, dtype=torch.float32, device=self.device), + fp8_dtype=self._fp8_dtype, + ) def quantize_( self, tensor: torch.Tensor, *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, ) -> Float8Tensor: """Update FP8 data @@ -564,181 +247,47 @@ def quantize_( ---------- tensor: torch.Tensor Tensor to copy from - scale: torch.Tensor, optional - Scaling factor to use for FP8 quantization - amax: torch.Tensor, optional - History of maximum absolute values. The first entry will - be updated with the absmax of `tensor`. 
noop_flag: torch.Tensor, optional float32 flag indicating whether to avoid performing update """ - src = tensor - dst = self - - # In-place operations invalidate transpose cache - self._reset_caches() - - # Special logic if other tensor is Float8Tensor - if isinstance(src, Float8Tensor): - - # Cast to plain tensor if FP8 dtypes don't match - if dst._fp8_dtype != src._fp8_dtype: - return dst.quantize_(src.dequantize()) - - # Directly copy FP8 data - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.detach()) - if amax is not None or dst._fp8_meta is not None: - src_amax: torch.Tensor - if src._fp8_meta is None: - src_min, src_max = src.dequantize().aminmax() - src_amax = torch.maximum(-src_min, src_max) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=src._fp8_meta_forward, - ) - fp8_meta_index = src._fp8_meta_index - src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] - dst_amax: torch.Tensor - if amax is None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] - else: - dst_amax = amax - if dst_amax.dim() > 0: - dst_amax = dst_amax[tuple([0] * dst_amax.dim())] - torch.maximum(src_amax, dst_amax, out=dst_amax) - if dst._transpose is not None: - if src._transpose is None: - dst.transpose_2d(force_compute=True, fill_cache=True) - else: - dst._transpose.copy_(src._transpose) - dst._transpose_invalid = False - return self + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self - # Convert QuantizedTensor to plain tensor - if isinstance(src, QuantizedTensor): - return dst.quantize_(src.dequantize()) + def detach(self) -> Float8Tensor: + # pylint: disable=missing-function-docstring + return 
Float8Tensor.make_like(self) - # Make sure input is in expected format - if src.size() != dst.size(): - src = src.expand(dst.size()) - if not devices_match(src.device, dst.device): - src = src.to(device=dst.device) - if src.dtype not in (torch.float32, torch.bfloat16, torch.float16): - src = src.float() - if not src.is_contiguous(): - src = src.contiguous() + def _create_transpose(self): + data = self._data + if not data.is_contiguous(): + data = data.contiguous() + self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose) + self._transpose_invalid = False - # Make sure FP8 scaling factors are in expected format - if scale is not None: - if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: - scale = scale.to(device=dst.device, dtype=torch.float32) - if amax is not None: - while amax.dim() < 2: - amax = amax.unsqueeze(0) - if not devices_match(amax.device, dst.device): - raise ValueError( - f"Invalid device for amax (expected {dst.device}, found {amax.device})" - ) - if amax.dtype != torch.float32: - raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})") - - # Default FP8 scaling factors - fp8_meta = None - if dst._fp8_meta is None: - if scale is None: - scale = dst._scale_inv.reciprocal() - if amax is None: - amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device) + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor" + if rowwise_usage: + assert self._data is not None, "Rowwise usage of the tensor was already disabled" else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta = dst._fp8_meta[fp8_meta_key] - - # Check local data - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - - # Perform FP8 cast - if dst._transpose is None: - dst_data = dst._data - if 
src.dim() != 2: - src = src.view(1, -1) - dst_data = dst_data.view(1, -1) - cast_to_fp8( - src, - fp8_meta, - dst._fp8_meta_index, - dst._fp8_dtype, - out=dst_data, - scale=scale, - amax=amax, - scale_inv=dst._scale_inv, - ) + if not non_tn_fp8_gemm_supported(): + if self._transpose is None or self._transpose_invalid: + self._create_transpose() + self._data = None + if columnwise_usage: + if self._transpose is None or self._transpose_invalid: + assert self._data is not None, "The tensor does not hold any data anymore" + if not non_tn_fp8_gemm_supported(): + self._create_transpose() else: - fp8_cast_transpose_fused( - src.view(-1, src.size(-1)), - fp8_meta, - dst._fp8_meta_index, - dst._fp8_dtype, - cast_out=dst._data, - transpose_out=dst._transpose, - scale=scale, - amax=amax, - scale_inv=dst._scale_inv, - noop_flag=noop_flag, - ) - dst._transpose_invalid = False - - return self - - @classmethod - def to_float8( - cls, - tensor: torch.Tensor, - *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - data: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - data_transpose: Optional[torch.Tensor] = None, - ): - """Construct Float8Tensor from plain PyTorch tensor""" - return _ToFloat8Func.apply( - tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - data, - scale, - amax, - scale_inv, - with_transpose_cache, - data_transpose, - ) - - def detach(self) -> Float8Tensor: - # pylint: disable=missing-function-docstring - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - ) + self._transpose = None + self._transpose_invalid = True def clone(self) -> Float8Tensor: # pylint: disable=missing-function-docstring + assert self._data is not None data = 
self._data.detach().clone() data_transpose = None if self._transpose is not None: @@ -761,7 +310,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8Tensor: def contiguous( self, - *, memory_format: torch.memory_format = torch.contiguous_format, ) -> Float8Tensor: """Returns tensor with data in provided memory format @@ -769,148 +317,15 @@ def contiguous( Returns `self` if data is already in correct memory format. """ - if self._data.is_contiguous(memory_format=memory_format): + if self._data is not None and self._data.is_contiguous(memory_format=memory_format): return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. - - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. - fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. 
- cache: bool, deprecated - - """ - - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = ( - force_compute - or (self._transpose is None) - or self._transpose_invalid - or (noop_flag is not None) - ) - - # Return cached transpose if possible - if not need_compute: - assert self._transpose is not None - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out: Optional[torch.Tensor] = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Tensor is reshaped as a 2D matrix. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. 
- - """ - if self._transpose is None: - self._transpose = torch.empty( - (self.size(-1), self.numel() // self.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self.quantize_(tensor, noop_flag=noop_flag) - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. - - """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: - """Create `Float8Tensor` with given nominal dtype - - The new tensor has the same underlying FP8 data. + if self._transpose is not None and self._transpose.is_contiguous( + memory_format=memory_format + ): + return self + return Float8Tensor.make_like(tensor=self, data=self._data.contiguous()) - """ - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - dtype=dtype, - ) + # raise ValueError("Float8Tensor does not support different memory formats!") def _reset_caches(self) -> None: """ @@ -919,32 +334,55 @@ def _reset_caches(self) -> None: """ self._transpose_invalid = True + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + self._data = torch.Tensor() if self._data is not None else None + self._transpose = torch.Tensor() if self._transpose is not None else None + self._transpose_invalid = True + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # Slice op - if func == aten.slice.Tensor: + # View op + if func == aten.view.default: tensor = args[0] data = tensor._data - data_slice = data.__torch_dispatch__( + out_data = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_slice) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if ( + out_transpose_shape[0] != out_shape[-1] + or out_transpose_shape[1:] != out_shape[:-1] + ): + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=False, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) - # View op - if func == aten.view.default: + if func in [aten.slice.Tensor, aten.select.int]: tensor = args[0] data = tensor._data - data_view = data.__torch_dispatch__( + data_slice = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_view) + return Float8Tensor.make_like(tensor, data=data_slice, shape=data_slice.shape) # Related to FSDP2 if func == aten.split.Tensor: @@ -982,8 +420,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.clone.default: return cls.clone(args[0]) if func == torch.ops.aten.copy_.default: - # Implementation in the superclass (QuantizedTensor) returns a proper output - pass + dst, src = args[0], args[1] + # Just copy FP8 attrs if copying between Float8Tensors + if isinstance(src, Float8Tensor) 
and isinstance(dst, Float8Tensor): + dst._data.copy_(src._data.detach()) + dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) + if src._transpose is not None or dst._transpose is not None: + dst._create_transpose() + return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 warnings.warn( @@ -1002,6 +446,7 @@ def _make_in_reduce_ex( fp8_dtype: TE_DType, fp8_scale_inv: torch.Tensor, dtype: torch.dtype, + shape: torch.shape, ) -> Float8Tensor: """Build Float8Tensor, for use in __reduce__ @@ -1014,13 +459,14 @@ def _make_in_reduce_ex( fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects""" return ( Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype), + (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), ) def _get_data(self) -> Float8Tensor: @@ -1039,12 +485,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device - # Check whether grad is required - if self.requires_grad != tensor.requires_grad: - self.requires_grad_(requires_grad=tensor.requires_grad) - # Just copy FP8 data if other tensor is Float8Tensor if isinstance(tensor, Float8Tensor): + + # PyTorch tensor attributes if ( # pylint: disable=too-many-boolean-expressions self.size() != tensor.size() or self.stride() != tensor.stride() @@ -1065,57 +509,110 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + + # Float8Tensor attributes self._data = tensor._data - self._fp8_attrs = tensor._fp8_attrs + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + 
self._scale_inv = tensor._scale_inv + self._transpose = tensor._transpose + self._transpose_invalid = tensor._transpose_invalid return - # Reallocate FP8 data if needed - if ( - self.size() != tensor.size() - or self.stride() != tensor.stride() - or self.dtype != tensor.dtype - or self.layout != tensor.layout - or not devices_match(self.device, new_device) - ): - self._data = torch.empty_like( - tensor, - dtype=torch.uint8, - device=new_device, - ) - dummy_tensor = torch.Tensor._make_wrapper_subclass( - Float8Tensor, - self._data.size(), - strides=self._data.stride(), - storage_offset=self._data.storage_offset(), - dtype=tensor.dtype, - layout=self._data.layout, - requires_grad=tensor.requires_grad, - device=self._data.device, - ) - # pylint: disable=unnecessary-dunder-call - super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) - if self._transpose is not None: - self._transpose = torch.empty( - (self._data.size(-1), self._data.numel() // self._data.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self._transpose_invalid = True - - # Copy values from other tensor - self.quantize_(tensor) + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) # Cast to FP8 when setting Float8Tensor.data data = property(_get_data, _set_data) - # Accessors for objects in self._fp8_attrs - # Note: We store FP8 attributes in a dictionary so we can share - # them between tensors with the same data, e.g. detached tensors. - # For convenience, we also expose them as property attributes. 
- _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) - _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) - _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) - _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + shape: Optional[list[int]] = None, + ) -> Float8Tensor: + # pylint: disable=missing-function-docstring + ctx.shape = tensor.shape + if shape is None: + return tensor.detach() + out_data = tensor._data.view(*shape) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + return grad.reshape(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. 
+ + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + shape: Tuple[int], + ) -> Float8Tensor: + # pylint: disable=missing-function-docstring + ctx.shape = tensor.shape + if shape is None: + return tensor.detach() + out_data = tensor._data.reshape(*shape) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + return grad.reshape(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py new file mode 100644 index 0000000000..86b13415a1 --- /dev/null +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -0,0 +1,582 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Tensor class with FP8 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple + +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ..constants import MXFP8_BLOCK_SCALING_SIZE +from ..utils import devices_match, round_up_to_nearest_multiple + +from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc + +aten = torch.ops.aten + + +class MXFP8Quantizer(Quantizer): + """Builder class for FP8 tensors with MX block scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized by + dividing them into groups of 32 elements, each scaled and cast + separately using current data. + + """ + + dtype: TE_DType + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp8_dtype + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, MXFP8Tensor), f"Cannot store quantized MXFP8 in {type(dst)} type." 
+ + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update FP8 dtype + dst._fp8_dtype = self.dtype + + return dst + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> MXFP8Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert ( + shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 + and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0 + ), ( + f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by" + f" {MXFP8_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty_like(data) + columnwise_scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + dtype=torch.uint8, + device=device, + ) + + # Construct FP8 tensor + return MXFP8Tensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # TODO(ksivamani): No calibration needed for mxfp8? 
+ pass + + +class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __repr__(self, *, tensor_contents=None): + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from MXFP8Tensor + + By default the resulting tensor's dtype is the + MXFP8Tensor's nominal dtype. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromMXFP8Func.apply(self, dtype) + return _FromMXFP8Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. 
+ + """ + if self._quantizer is not None: + return self._quantizer + return MXFP8Quantizer( + fp8_dtype=self._fp8_dtype, + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> MXFP8Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return MXFP8Tensor.make_like(self) + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """ + For MXFP8, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. + """ + assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor." + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + + assert ( + self._columnwise_data is not None and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." 
+ self._rowwise_data = None + self._rowwise_scale_inv = None + return + + def clone(self) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> MXFP8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("MXFP8Tensor does not support different memory formats!") + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return MXFP8Tensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + ) -> MXFP8Tensor: + """Build MXFP8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. 
+ + """ + return MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + MXFP8Tensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + ), + ) + + def _get_data(self) -> MXFP8Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a MXFP8Tensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + # Just copy FP8 data if other tensor is MXFP8Tensor + if isinstance(tensor, MXFP8Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + MXFP8Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert 
self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting MXFP8Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the MXFP8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: MXFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != ctx.shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Construct new tensor if shape is provided + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.view(*shape) + if tensor._columnwise_data is not None: + columnwise_shape = [shape[-1]] + list(shape[:-1]) + new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + return MXFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # 
pylint: disable=missing-function-docstring + + if isinstance(grad, MXFP8Tensor): + new_data = ( + grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None + ) + if grad._columnwise_data is not None: + new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1) + else: + new_columnwise_data = None + dgrad = MXFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the MXFP8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: MXFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != ctx.shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Construct new tensor if shape is provided + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.reshape(*shape) + if tensor._columnwise_data is not None: + columnwise_shape = [shape[-1]] + list(shape[:-1]) + new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + + return MXFP8Tensor( + shape, + 
tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, MXFP8Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + new_rowwise_data = grad._rowwise_data.view(*ctx.shape) + if grad._columnwise_data is not None: + columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1]) + new_columnwise_data = grad._columnwise_data.view(columnwise_shape) + dgrad = MXFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 550e113389..707382696d 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -5,23 +5,192 @@ """Tensor with quantized data""" from __future__ import annotations -from typing import Optional, Tuple +from typing import Optional, Tuple, Iterable, Any, Dict, Union +import abc +import copy import torch from torch.utils._pytree import tree_map +import transformer_engine_torch as tex + + +def prepare_for_saving( + *tensors, +) -> Tuple[list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], Optional[Any]]: + """Prepare tensors for saving. 
Needed because save_for_backward accepts only + torch.Tensor/torch.nn.Parameter types, while we want to be able to save + the internal TensorBase types too.""" + # pylint: disable=unidiomatic-typecheck # Using type instead of isinstance to check exact type + tensor_list, tensor_objects_list = [], [] + for tensor in tensors: + if tensor is None: + tensor_list.append(None) + tensor_objects_list.append(None) + elif type(tensor) in (torch.Tensor, torch.nn.Parameter): + tensor_list.append(tensor.data) + tensor_objects_list.append(None) + else: + t, t_obj = tensor.prepare_for_saving() + tensor_list.extend(t) + tensor_objects_list.append(t_obj) + return tensor_list, tensor_objects_list + + +def restore_from_saved( + tensors: list[Optional[Any]], + saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], +) -> list[Optional[Any]]: + """Recombine the tensor data and metadata during backward pass.""" + tensor_objects = [] + for tensor in tensors: + if tensor is None: + tensor_objects.append(saved_tensors[0]) + saved_tensors = saved_tensors[1:] + else: + saved_tensors = tensor.restore_from_saved(saved_tensors) + tensor_objects.append(tensor) + return tensor_objects + + +class Quantizer(abc.ABC): + """Builder class for quantized tensors. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). -class _DequantizeFunc(torch.autograd.Function): - """Autograd function to convert quantized tensor to standard tensor""" + """ + + """Whether to construct quantized tensors with "row-wise usage" + + Hand-wave explanation: Consider the matrix multiplication C = A * + B^T (used in linear forward). Tensor Cores prefer "TN GEMMs" (in + Fortran-style column-major order), so A and B should be in + row-major order. 
+ + """ + rowwise_usage: bool + + """Whether to construct quantized tensors with "column-wise usage" + + Hand-wave explanation: Consider the matrix multiplication C = A^T + * B (used in linear backward wgrad). Tensor Cores prefer "TN + GEMMs" (in Fortran-style column-major order), so A and B should be + in column-major order. + + """ + columnwise_usage: bool + + """Whether to instantiates tensor for purely internal usage + + Internal tensors are storage classes with minimal logic. They have + less overhead than PyTorch tensor sub-classes, but are not + compatible with PyTorch's autograd infrastructure nor PyTorch + operations. + + """ + internal: bool + + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: + self.rowwise_usage = rowwise + self.columnwise_usage = columnwise + self.internal = False + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"rowwise_usage={self.rowwise_usage}, " + f"columnwise_usage={self.columnwise_usage}, " + f"internal={self.internal}, " + ")" + ) + + @abc.abstractmethod + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Quantize tensor in-place""" + + def quantize( + self, + tensor: torch.Tensor, + *, + out: Optional[QuantizedTensor] = None, + ) -> QuantizedTensor: + """Quantize tensor""" + if out is not None: + return self.update_quantized(tensor, out) + if (not self.internal) and torch.is_grad_enabled(): + return _QuantizeFunc.apply(tensor, self) + return _QuantizeFunc.forward(None, tensor, self) + + def multi_quantize(self, list_of_tensors): + """Quantize multiple tensors""" + list_of_output_tensors = [] + for tensor in list_of_tensors: + list_of_output_tensors.append(self.quantize(tensor)) + return list_of_output_tensors + + def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor""" + return self.quantize(tensor) + + @abc.abstractmethod + def make_empty( + self, + shape: 
Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Construct quantized tensor with uninitialized data""" + + @abc.abstractmethod + def calibrate(self, tensor: torch.Tensor) -> None: + """Calibrate quantizer state + + Updates quantization state as if quantizing a tensor, but + without actually performing the quantization. + + """ + + def set_usage( + self, + *, + rowwise: Optional[bool] = None, + columnwise: Optional[bool] = None, + ) -> None: + """Set how the quantized tensor is expected to be used + + See documentation for `rowwise_usage` and `columnwise_usage` + variables. + + """ + if rowwise is not None: + self.rowwise_usage = rowwise + if columnwise is not None: + self.columnwise_usage = columnwise + + def copy(self) -> Quantizer: + """Create shallow copy""" + return copy.copy(self) + + +class _QuantizeFunc(torch.autograd.Function): + """Cast to FP8 from other dtype""" @staticmethod def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: QuantizedTensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: torch.Tensor, + quantizer: Quantizer, + ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tensor.dequantize(dtype=dtype) + return tex.quantize(tensor, quantizer) @staticmethod def backward( @@ -29,27 +198,55 @@ def backward( grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision return grad, None class _IdentityFunc(torch.autograd.Function): - """Autograd function to create quantized tensor with same data""" + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. 
+ + """ @staticmethod def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused + ctx, tensor: QuantizedTensor, + init_kwargs: Optional[Dict[str, Any]] = None, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tensor.detach() + + # Return input tensor if constructor kwargs are not provided + if init_kwargs is None: + return tensor.detach() + + # Construct new tensor if constructor kwargs are provided + ctx.input_dtype = tensor.dtype + kwargs = tensor.get_metadata() + for key, val in init_kwargs.items(): + kwargs[key] = val + return type(tensor)(tensor.shape, tensor.dtype, **kwargs) @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> torch.Tensor: + def backward(ctx, grad_output): # pylint: disable=missing-function-docstring - return grad + grad_input = grad_output + if grad_input.dtype == ctx.input_dtype: + grad_input = grad_input.detach() + else: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input, None + + +def _stride_from_shape(shape: list[int]): + if len(shape) == 0: + return [] + rstride = [1] + for d in reversed(shape[1:]): + rstride.append(rstride[-1] * d) + return list(reversed(rstride)) class QuantizedTensor(torch.Tensor): @@ -62,6 +259,22 @@ class QuantizedTensor(torch.Tensor): """ + def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False): + # We are assuming only contiguous tensors + stride = _stride_from_shape(shape) + instance = torch.Tensor._make_wrapper_subclass( + cls, + shape, + strides=stride, + storage_offset=0, + dtype=dtype, + layout=torch.strided, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + ) + + return instance + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( @@ -85,24 +298,38 @@ def detach(self) -> QuantizedTensor: f"{self.__class__.__name__} class does not 
implement detach function" ) - def __repr__(self) -> str: + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """Indicate to the tensor how it is going to be used + + This enables optimizations to memory usage in some cases + where forward and backward passes use the tensor in + different directions. + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_usage function" + ) + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully""" + + def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" def float(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.float32) + return self.dequantize(dtype=torch.float32) def bfloat16(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.bfloat16) + return self.dequantize(dtype=torch.bfloat16) def half(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.float16) + return self.dequantize(dtype=torch.float16) - def cpu(self) -> torch.Tensor: + def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self).cpu() + return self.dequantize().cpu(memory_format=memory_format) def expand_as(self, other: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -179,3 +406,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + raise NotImplementedError( + 
f"{self.__class__.__name__} class does not implement contiguous function" + ) + + def get_metadata(self) -> Dict[str, Any]: + """Get keyword arguments for quantized tensor constructor + + Contains metadata so that the new quantized tensor has the + same underlying quantized data. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement get_metadata function" + ) + + @classmethod + def make_like( + cls, + tensor: QuantizedTensor, + *, + shape: Optional[Iterable[int]] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, + data: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Create new quantized tensor + + By default, new tensor has the same attributes and underlying + data. + + """ + shape = shape if shape is not None else tensor.shape + dtype = dtype if dtype is not None else tensor.dtype + kwargs = tensor.get_metadata() + if data is not None: + kwargs["data"] = data + return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) + + def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: + """Create `QuantizedTensor` with given nominal dtype + + The new tensor has the same underlying data. 
+ + """ + return self.__class__.make_like(self, dtype=dtype) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7c3da9a73f..97b1361163 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -267,11 +267,11 @@ def __init__( zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, ub_tp_comm_overlap: bool = False, - ub_bulk_wgrad: bool = True, - ub_bulk_dgrad: bool = True, ub_overlap_ag: bool = True, ub_overlap_rs: bool = True, ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = True, + ub_bulk_wgrad: bool = True, bias: bool = True, activation: str = "gelu", normalization: str = "LayerNorm", diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 63b2f2cfb5..5b1bd82221 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -6,11 +6,13 @@ from __future__ import annotations import functools import math -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import torch import transformer_engine.pytorch.cpp_extensions as ext +from .tensor.quantized_tensor import QuantizedTensor + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" @@ -27,12 +29,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. 
""" - from .float8_tensor import Float8Tensor - for t in tensors: if t is not None: - if isinstance(t, Float8Tensor): - t._data.data = torch.Tensor() + if isinstance(t, QuantizedTensor): + t.clear() else: t.data = torch.Tensor() del t @@ -231,14 +231,15 @@ def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 -def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: - """Assert that tensor dimensions are supported for FP8 TN GEMM""" - # single tensor check so it's clear which tensor is triggering the assertion - assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( - "FP8 execution requires 2D input matrices with " - "height divisible by 8 and width divisible by 16, " - f"but got tensor with dims={list(tensor.size())}" - ) +def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: + """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" + + for tensor in tensors: + assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( + "FP8 execution requires 2D input matrices with " + "height divisible by 8 and width divisible by 16, " + f"but got tensor with dims={list(tensor.size())}" + ) def is_bf16_compatible() -> None: @@ -248,6 +249,13 @@ def is_bf16_compatible() -> None: return torch.cuda.get_device_capability()[0] >= 8 +def non_tn_fp8_gemm_supported() -> bool: + """Checks whether the device supports + non-TN layouts for FP8 GEMMs. 
+ """ + return torch.cuda.get_device_capability() >= (10, 0) + + @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" @@ -305,3 +313,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: index2 = torch.cuda.current_device() return index1 == index2 return device1 == device2 + + +@functools.lru_cache +def get_sm_count() -> int: + """Returns the number of streaming multiprocessors in the current device.""" + return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + + +def round_up_to_nearest_multiple(value, multiple): + """Round up `value` to the next mutiple of `multiple`""" + if multiple == 0: + raise ValueError("multiple cannot be zero.") + return ((value + multiple - 1) // multiple) * multiple From 8dc06e09c6bdd3caf4c928e257da3b29b9b925e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Feb 2025 20:32:35 +0000 Subject: [PATCH 066/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions.h | 3 +-- transformer_engine/pytorch/csrc/extensions/attention.cu | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 33fc291a9d..d29ab52e3f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -52,8 +52,7 @@ std::vector fused_attn_fwd( const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional page_table_k, - const c10::optional page_table_v, + const c10::optional page_table_k, const c10::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, const 
c10::optional rng_gen, size_t rng_elts_per_thread); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 7d080f59f2..9c4285964c 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -96,8 +96,7 @@ std::vector fused_attn_fwd( const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, const c10::optional page_table_k, const c10::optional page_table_v, - py::handle s_quantizer, - py::handle o_quantizer, const c10::optional Bias, + py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; using namespace transformer_engine::pytorch; From 59dcf48a8aa032c30e7dabdca94689ff9cf4383e Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Mon, 10 Feb 2025 15:08:05 -0800 Subject: [PATCH 067/239] WIP: minor fix/preparation for inference/cuda graph Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 456 +++++++++++--------- transformer_engine/pytorch/attention.py | 54 ++- transformer_engine/pytorch/graph.py | 41 +- 3 files changed, 308 insertions(+), 243 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index ce40473c51..8250e53953 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -30,77 +30,6 @@ _cuda_rng_state = torch.cuda.get_rng_state() -class Batch(object): - def __init__(self): - self.batch_size = 0 - self.seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu") - self.ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") - self.gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") - self.total_lens = self.ctx_lens + self.gen_lens - self.expected_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") - 
self.finished = torch.Tensor([]).to(dtype=torch.bool, device="cpu") - self.step_lens_q = torch.Tensor([]).to(dtype=torch.int32, device="cpu") - - def copy(self): - new_batch = Batch() - new_batch.batch_size = self.batch_size - new_batch.seq_ids = self.seq_ids - new_batch.ctx_lens = self.ctx_lens - new_batch.gen_lens = self.gen_lens - new_batch.total_lens = self.total_lens - new_batch.expected_gen_lens = self.expected_gen_lens - new_batch.finished = self.finished - new_batch.step_lens_q = self.step_lens_q - return new_batch - - def print(self, logger, header="current batch:"): - logger.debug(header) - logger.debug(" {:<17s}: {}".format("batch_size", self.batch_size)) - logger.debug(" {:<17s}: {}".format("seq_ids", self.seq_ids.tolist())) - logger.debug(" {:<17s}: {}".format("ctx_lens", self.ctx_lens.tolist())) - logger.debug(" {:<17s}: {}".format("gen_lens", self.gen_lens.tolist())) - logger.debug(" {:<17s}: {}".format("total_lens", self.total_lens.tolist())) - logger.debug(" {:<17s}: {}".format("expected_gen_lens", self.expected_gen_lens.tolist())) - logger.debug(" {:<17s}: {}".format("finished", self.finished.tolist())) - logger.debug(" {:<17s}: {}".format("step_lens_q", self.step_lens_q.tolist())) - - def add_new_seqs(self, seq_ids, context_lens, expected_gen_lens): - ctx_lens = context_lens[seq_ids] - gen_lens = torch.Tensor([0] * len(seq_ids)).to(dtype=torch.int32, device="cpu") - exp_gen_lens = expected_gen_lens[seq_ids] - finished = torch.Tensor([False] * len(seq_ids)).to(dtype=torch.bool, device="cpu") - - self.batch_size = self.batch_size + len(seq_ids) - self.finished = torch.cat([self.finished, finished], dim=0) - - if len(self.seq_ids) == 0: - self.seq_ids = seq_ids - self.ctx_lens = ctx_lens - self.gen_lens = gen_lens - self.expected_gen_lens = exp_gen_lens - else: - self.seq_ids = torch.cat([self.seq_ids, seq_ids], dim=0) - self.ctx_lens = torch.cat([self.ctx_lens, ctx_lens], dim=0) - self.gen_lens = torch.cat([self.gen_lens, gen_lens], dim=0) - 
self.expected_gen_lens = torch.cat([self.expected_gen_lens, exp_gen_lens], dim=0) - self.total_lens = self.ctx_lens + self.gen_lens - self.step_lens_q = torch.cat([self.step_lens_q, ctx_lens], dim=0) - - def remove_finished(self): - self.finished = torch.where(self.gen_lens - self.expected_gen_lens < 0, False, True).to( - dtype=torch.bool, device="cpu" - ) - self.batch_size = self.finished.logical_not().sum().item() - self.seq_ids = self.seq_ids[~self.finished] - self.ctx_lens = self.ctx_lens[~self.finished] - self.gen_lens = self.gen_lens[~self.finished] - self.total_lens = self.total_lens[~self.finished] - self.expected_gen_lens = self.expected_gen_lens[~self.finished] - self.gen_lens = self.gen_lens + 1 - self.total_lens = self.total_lens + 1 - self.step_lens_q = torch.ones([self.batch_size], dtype=torch.int32, device="cpu") - - param_types = [torch.float16] if is_bf16_compatible(): param_types.append(torch.bfloat16) @@ -108,7 +37,7 @@ def remove_finished(self): model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias "infer_0": ModelConfig(4, 16, 16, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=8), - "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), + #"infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -117,13 +46,162 @@ def remove_finished(self): def to_pretty_string(x: torch.Tensor): return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]" +class Simulation: + def __init__( + self, + total_requests: int = 10, + max_seqlen_kv: int = 1024, + context_ratio: float = 0.25, + max_batch_size: int = 5, + poisson_rate: float = None, + ): + self.total_requests = total_requests + self.max_seqlen_kv = max_seqlen_kv + self.context_ratio = context_ratio + self.max_batch_size = max_batch_size + self.poisson_rate = poisson_rate + + # calculate maximum context/generation length + self.max_context_len = 
int(max_seqlen_kv * context_ratio) + self.max_gen_len = max_seqlen_kv - self.max_context_len + + # simulate sequence ids in monotonically increasing fashion + self.seq_ids = torch.range(0, total_requests-1, dtype=torch.int32, device="cpu") + + # simulate context lengths in Uniform distribution + self.context_lens = torch.randint( + 1, self.max_context_len, [total_requests], dtype=torch.int32, device="cpu" + ) + + # simulate gen lengths in Exponential distribution + gen_dist = Exponential(1 / self.max_gen_len) + gen_lens = gen_dist.sample((total_requests,)) + gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to( + dtype=torch.int32, device="cpu" + ) + self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( + dtype=torch.int32, device="cpu" + ) -@pytest.mark.parametrize("dtype", param_types) + # simulate arrival times in Poisson distribution + if poisson_rate is None: + self.poisson_rate = torch.randint(1, max_batch_size, [1]).item() + interval_dist = Exponential(self.poisson_rate) + arrival_intervals = interval_dist.sample((total_requests,)) + self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") + self.last_arrival = self.arrival_times.max().item() + + # initialize tensors + self.reset() + + def reset(self): + self.t = 0 + self.request_delays = torch.zeros([self.total_requests], dtype=torch.int32, device="cpu") + self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu") + self.serving_times = self.arrival_times + self.complete_times = self.arrival_times + + # time-stepping workflow + # t-1: ... 
+ # compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4], + # batch_size = 3, step_lens = [1, 1, 1] + # increase counter for gen_lens = [3, 10, 5] + # t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15] + # add two new seqs 3 and 4, with ctx lens 10 and 11 + # compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0], + # batch_size = 4, step_lens = [1, 1, 10, 11] + # increase counter for gen_lens = [3, 5, 1, 1] + + # batch info at step t + self.t_seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.t_ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.t_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.t_total_lens = self.t_ctx_lens + self.t_gen_lens #+ self.step_lens + self.t_batch_size = 0 + + # step info from step t-1 to t + self.step_lens = torch.Tensor([]).to(dtype=torch.int32, device="cpu") + + def print(self, logger, label="setup"): + if label == "setup": + logger.info("Simulation:") + logger.info(" {:<33s}: {}".format("total number of requests", self.total_requests)) + logger.info(" {:<33s}: {}".format("max sequence length per request", self.max_seqlen_kv)) + logger.info(" {:<33s}: {}".format("max context lengh", self.max_context_len)) + logger.info(" {:<33s}: {}".format("max generation lengh", self.max_gen_len)) + logger.info(" {:<33s}: {}".format("max batch size per iteration", self.max_batch_size)) + logger.info(" {:<33s}: {}".format("Poisson rate", self.poisson_rate)) + logger.info(" {:<18s}: {}".format("sequence ids", to_pretty_string(self.seq_ids))) + logger.info(" {:<18s}: {}".format("arrival times", to_pretty_string(self.arrival_times))) + logger.info(" {:<18s}: {}".format("context lenghs", to_pretty_string(self.context_lens))) + logger.info(" {:<18s}: {}".format("generation lenghs", to_pretty_string(self.gen_lens))) + if label == "step": + logger.info(f"Step t = {self.t}:") + logger.info(" {:<15s}: {}".format("t_batch_size", 
self.t_batch_size)) + logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist())) + logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist())) + logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist())) + logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist())) + logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist())) + if label == "summary": + logger.info("Summary:") + logger.info(" {:<18s}: {}".format("total steps taken", self.t)) + logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times))) + logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times))) + logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens))) + logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times))) + + def add_new_seqs(self, new_seq_ids): + # get ctx_lens for new seqs + self.t_seq_ids = torch.cat([self.t_seq_ids, new_seq_ids], dim=0) + self.t_ctx_lens = torch.cat([self.t_ctx_lens, self.context_lens[new_seq_ids]], dim=0) + gen_lens = torch.Tensor([0] * len(new_seq_ids)).to(dtype=torch.int32, device="cpu") + self.t_gen_lens = torch.cat([self.t_gen_lens, gen_lens], dim=0) + # append new seqs' ctx_lens to step_lens + self.step_lens = torch.cat([self.step_lens, self.context_lens[new_seq_ids]], dim=0) + + def remove_finished(self): + # figure out which seqs have finished + finished = torch.where(self.t_gen_lens - self.gen_lens[self.t_seq_ids] < 0, False, True).to( + dtype=torch.bool, device="cpu" + ) + self.t_seq_ids = self.t_seq_ids[~finished] + self.t_ctx_lens = self.t_ctx_lens[~finished] + self.t_gen_lens = self.t_gen_lens[~finished] + # add ones for unfinished seqs to step_lens + self.step_lens = torch.ones([len(self.t_seq_ids)], dtype=torch.int32, device="cpu") + + def step(self, dynamic_fill: bool = True): + # remove finished seqs + if self.t != 0: + self.remove_finished() + # get allowed 
new seqs + arrived_seq_ids = torch.where(self.arrival_times == self.t, True, False).nonzero().view(-1) + queuing_seq_ids = torch.cat([self.delayed_seq_ids, arrived_seq_ids], dim=0) + if dynamic_fill: + allowed_num_new_seqs = self.max_batch_size - len(self.t_seq_ids) + else: + allowed_num_new_seqs = 0 if len(self.t_seq_ids) else self.max_batch_size + if len(queuing_seq_ids) > allowed_num_new_seqs: + new_seq_ids = queuing_seq_ids[:allowed_num_new_seqs] + self.delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:] + self.request_delays[self.delayed_seq_ids.tolist()] += 1 + else: + new_seq_ids = queuing_seq_ids + self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32) + # add new seqs to batch + self.add_new_seqs(new_seq_ids) + # update batch variables + self.t_batch_size = len(self.t_seq_ids) + self.t_total_lens = self.t_ctx_lens + self.t_gen_lens + + +@pytest.mark.parametrize("dtype", [torch.float16])#param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", qkv_formats) -@pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("is_cuda_graph", [False, True]) +@pytest.mark.parametrize("qkv_format", ['bshd'])#qkv_formats) +@pytest.mark.parametrize("is_paged", [False])#, True]) +@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("is_cuda_graph", [False])#, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() logger = logging.getLogger("test_paged_attn") @@ -131,6 +209,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): config = model_configs_infer[model] layer_number = 1 + # figure out supported backends inference_params_qkv_format = "bshd" if is_paged: qkv_layout = "paged_kv_" + inference_params_qkv_format + "_2" + 
inference_params_qkv_format @@ -150,68 +229,37 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): pytest.skip("FusedAttention backend is not supported") if backend == "UnfusedAttention" and not unfused_attn_supported: pytest.skip("UnfusedAttention backend is not supported") - os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention")) os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention")) os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) + # set up various parameters total_requests = config.total_requests - # max_batch_size may be smaller than total_requests max_batch_size = config.batch_size - # maximum KV length (context + generation) max_seqlen_kv = config.max_seqlen_kv - - # mask type for inference attn_mask_type = "padding" - - # page size in number of tokens (k cache and v cache are separate) page_size = 256 if backend == "FlashAttention" else 16 - max_seqlen_kv_roundup = max_seqlen_kv if is_paged: + # round up max_seqlen_kv to nearest page size max_seqlen_kv_roundup = int((max_seqlen_kv + page_size - 1) // page_size * page_size) else: + # round up max_seqlen_kv to nearest multiple of 64 max_seqlen_kv_roundup = int((max_seqlen_kv + 63) // 64 * 64) cache_size = max_batch_size * max_seqlen_kv_roundup total_num_pages = int(cache_size / page_size) - context_ratio = 0.25 - gen_ratio = 1 - context_ratio - max_context_len = int(max_seqlen_kv * context_ratio) - max_gen_len = int(max_seqlen_kv * gen_ratio) - - # context lengths in Uniform distribution - context_lens = torch.randint( - 1, max_context_len, [total_requests], dtype=torch.int32, device="cpu" - ) - # generation lengths in Exponential distribution - gen_dist = Exponential(1 / max_gen_len) - gen_lens = gen_dist.sample((total_requests,)) - gen_lens = torch.where(gen_lens > max_gen_len, max_gen_len, gen_lens).to( - dtype=torch.int32, device="cpu" - ) - # arrival times in Poisson distribution - rate = torch.randint(1, 
max_batch_size, [1]).item() - interval_dist = Exponential(rate) - arrival_intervals = interval_dist.sample((total_requests,)) - arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") - last_arrival = arrival_times.max().item() - - logger.info("Simulation:") - logger.info(f" total num of requests: {total_requests}") - logger.info(f" k/v cache size: {cache_size} tokens") - logger.info(f" is_paged: {is_paged}") - logger.info(f" dtype: {dtype}") - if not is_paged: - logger.info(f" max_batch_size: {max_batch_size}") - logger.info(f" max_seqlen_kv: {max_seqlen_kv}") - else: - logger.info(f" total_num_pages: {total_num_pages}") - logger.info(f" page_size: {page_size}") - logger.info(f" context_lens: {to_pretty_string(context_lens)}") - logger.info(f" expected_gen_lens: {to_pretty_string(gen_lens)}") - logger.info(f" arrival_times: {to_pretty_string(arrival_times)}") + # set up simulation + sim = Simulation( + total_requests=total_requests, + max_seqlen_kv=max_seqlen_kv, + context_ratio=0.25, + max_batch_size=max_batch_size, + poisson_rate=2, + ) + sim.print(logger, label="setup") + # create model and data model = ( DotProductAttention( kv_channels=config.head_dim_qk, @@ -225,7 +273,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): .cuda() .eval() ) - q = 0.1 * torch.randn( (total_requests, max_seqlen_kv, config.num_heads, config.head_dim_qk), dtype=dtype, @@ -242,9 +289,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): device="cuda", ) - logger.info("") + # generate all tokens at once logger.info("=== Generating all tokens at once ===") - request_delays = torch.zeros([total_requests], dtype=torch.int32, device="cpu") full_output = model( query_layer=q, key_layer=k, @@ -253,13 +299,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): attn_mask_type="causal", ) - t = 1 - logger.info(f"total steps taken: {t}") - 
logger.info(f"arrival_times: {to_pretty_string(arrival_times)}") - logger.info(f"gen_lens: {to_pretty_string(gen_lens)}") - logger.info(f"serving_times: {to_pretty_string(arrival_times + request_delays)}") - - logger.info("") + # generate tokens one at a time logger.info("=== Generating one token at a time ===") inference_params = InferenceParams( max_batch_size=max_batch_size, @@ -278,51 +318,81 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): inference_params.allocate_memory(layer_number) inference_params.print() - request_delays = torch.zeros([total_requests], dtype=torch.int32, device="cpu") - t = 0 - prev = Batch() - delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu") +# def generate_data( +# model_config: ModelConfig, +# dtype: torch.dtype, +# warmup: bool = False, +# ) -> List[torch.Tensor]: +# """Generate synthetic data for dot product attention.""" +# gen_func = torch.ones if warmup else torch.randn +# aa=[ +# gen_func( +# model_config.batch_size, +# model_config.max_seqlen_q, +# model_config.num_heads, +# model_config.head_dim_qk, +# device="cuda", +# #requires_grad=True, +# dtype=dtype, +# ) +# for _ in range(3) +# ] +# #aa.extend([model_config.sequence_length, model_config.sequence_length]) +# return aa +# +# def gen_cu( +# model_config: ModelConfig, +# dtype: torch.dtype, +# ): +# cu_dict = {} +# cu_dict["cu_seqlens_q"] = torch.linspace( 0, +# model_config.batch_size * model_config.max_seqlen_q, +# steps=model_config.batch_size+1, +# device="cuda", +# dtype=torch.int32, +# ) +# cu_dict["cu_seqlens_kv"] = torch.linspace( 0, +# model_config.batch_size * model_config.max_seqlen_kv, +# steps=model_config.batch_size+1, +# device="cuda", +# dtype=torch.int32, +# ) +# cu_dict["max_seqlen_q"] = model_config.max_seqlen_q +# cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv +# return cu_dict +# +# model = make_graphed_callables( +# model, +# generate_data_for_dot_product_attention(model_config, dtype, 
warmup=True), +# num_warmup_iters=10, +# fp8_enabled=False, +# #sample_kwargs={"qkv_format":"thd"}, +# sample_kwargs=gen_cu(model_config, dtype), +# ) + + # similate step by step + sim.reset() while True: - logger.debug(f"time step {t}") - cur = prev.copy() - if t != 0: - cur.remove_finished() if inference_params.is_paged: inference_params.cache_manager.print_cache() - arrived_seq_ids = torch.where(arrival_times == t, True, False).nonzero().view(-1) - if inference_params.is_paged: - allowed_num_new_seqs = max_batch_size - cur.batch_size - else: - allowed_num_new_seqs = 0 if cur.batch_size > 0 else max_batch_size - queuing_seq_ids = torch.cat([delayed_seq_ids, arrived_seq_ids], dim=0) - logger.debug(f"arrived seq_ids: {to_pretty_string(arrived_seq_ids)}") - logger.debug(f"previously delayed seq_ids: {to_pretty_string(delayed_seq_ids)}") - logger.debug(f"allowed num of new sequences: {allowed_num_new_seqs}") - if len(queuing_seq_ids) > allowed_num_new_seqs: - seq_ids = queuing_seq_ids[:allowed_num_new_seqs] - delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:] - request_delays[delayed_seq_ids.tolist()] += 1 - else: - seq_ids = queuing_seq_ids - delayed_seq_ids = torch.Tensor().to(dtype=torch.int32) - cur.add_new_seqs(seq_ids, context_lens, gen_lens) - cur.print(logger) - if inference_params.is_paged: - inference_params.cache_manager.print_cache() + dynamic_fill = True #inference_params.is_paged + sim.step(dynamic_fill=dynamic_fill) + sim.print(logger, label="step") - if cur.batch_size == 0: + if sim.t_batch_size == 0: # all sequences are finished - if t > last_arrival: + if sim.t > sim.last_arrival: + sim.serving_times = sim.arrival_times + sim.request_delays + sim.complete_times = sim.serving_times + sim.gen_lens break # not finished; run next iteration else: - prev = cur.copy() - t += 1 + sim.t += 1 continue if not is_cuda_graph: - max_seqlen_q_infer = int((cur.step_lens_q.max().item() + 63) // 64 * 64) + max_seqlen_q_infer = int((sim.max_context_len + 63)// 
64 * 64) else: max_seqlen_q_infer = max_seqlen_kv_roundup @@ -331,9 +401,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") incremental_k = torch.Tensor().to(dtype=dtype, device="cuda") incremental_v = torch.Tensor().to(dtype=dtype, device="cuda") - for i, seq in enumerate(cur.seq_ids): - start = (cur.total_lens[i] - cur.step_lens_q[i]).item() - end = cur.total_lens[i].item() + for i, seq in enumerate(sim.t_seq_ids): + start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + end = sim.t_total_lens[i].item() incremental_q = torch.cat([incremental_q, q[seq, start:end, :, :]], dim=0) incremental_k = torch.cat( [ @@ -351,7 +421,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): ) else: incremental_q = torch.zeros( - cur.batch_size, + sim.t_batch_size, max_seqlen_q_infer, config.num_heads, config.head_dim_qk, @@ -359,7 +429,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): device="cuda", ) incremental_k = torch.zeros( - cur.batch_size, + sim.t_batch_size, max_seqlen_q_infer, config.num_gqa_groups, config.head_dim_qk, @@ -367,31 +437,31 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): device="cuda", ) incremental_v = torch.zeros( - cur.batch_size, + sim.t_batch_size, max_seqlen_q_infer, config.num_gqa_groups, config.head_dim_v, dtype=dtype, device="cuda", ) - for i, seq in enumerate(cur.seq_ids): - start = (cur.total_lens[i] - cur.step_lens_q[i]).item() - end = cur.total_lens[i].item() - incremental_q[i, : cur.step_lens_q[i], :, :] = q[seq, start:end, :, :] - incremental_k[i, : cur.step_lens_q[i], :, :] = k[seq, start:end, :, :] - incremental_v[i, : cur.step_lens_q[i], :, :] = v[seq, start:end, :, :] + for i, seq in enumerate(sim.t_seq_ids): + start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + end = sim.t_total_lens[i].item() + incremental_q[i, : sim.step_lens[i], 
:, :] = q[seq, start:end, :, :] + incremental_k[i, : sim.step_lens[i], :, :] = k[seq, start:end, :, :] + incremental_v[i, : sim.step_lens[i], :, :] = v[seq, start:end, :, :] if qkv_format == "sbhd": incremental_q, incremental_k, incremental_v = [ x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] ] - cu_seqlens_q = torch.zeros(cur.batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1 : cur.batch_size + 1] = torch.cumsum(cur.step_lens_q, dim=0) - cu_seqlens_kv = torch.zeros(cur.batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv[1 : cur.batch_size + 1] = torch.cumsum(cur.total_lens, dim=0) + cu_seqlens_q = torch.zeros(sim.t_batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) + cu_seqlens_kv = torch.zeros(sim.t_batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) inference_params.step_dict = OrderedDict( - zip(cur.seq_ids.tolist(), cur.step_lens_q.tolist()) + zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) ) line_output = model( @@ -419,33 +489,31 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): torch.half: 4e-3, torch.bfloat16: 1e-2, } - for i, seq in enumerate(cur.seq_ids): + for i, seq in enumerate(sim.t_seq_ids): if qkv_format == "bshd": torch.testing.assert_close( - full_output[seq, cur.total_lens[i] - 1, :], - line_output[i, cur.step_lens_q[i] - 1, :], + full_output[seq, sim.t_total_lens[i] - 1, :], + line_output[i, sim.step_lens[i] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "sbhd": torch.testing.assert_close( - full_output[seq, cur.total_lens[i] - 1, :], - line_output[cur.step_lens_q[i] - 1, i, :], + full_output[seq, sim.t_total_lens[i] - 1, :], + line_output[sim.step_lens[i] - 1, i, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "thd": torch.testing.assert_close( - full_output[seq, 
cur.total_lens[i] - 1, :], + full_output[seq, sim.t_total_lens[i] - 1, :], line_output[cu_seqlens_q[i + 1] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) + sim.t += 1 + sim.t_gen_lens = sim.t_gen_lens + 1 - prev = cur.copy() - t += 1 - - logger.info(f"total steps taken: {t}") - logger.info(f"arrival_times: {to_pretty_string(arrival_times)}") - logger.info(f"gen_lens: {to_pretty_string(gen_lens)}") - logger.info(f"serving_times: {to_pretty_string(arrival_times + request_delays)}") + sim.serving_times = sim.arrival_times + sim.request_delays + sim.complete_times = sim.serving_times + sim.gen_lens + sim.print(logger, label="summary") diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4e330a615b..67592f7550 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7400,14 +7400,14 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, - qkv_format: Optional[str] = None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_kv: Optional[torch.Tensor] = None, - cu_seqlens_q_padded: Optional[torch.Tensor] = None, - cu_seqlens_kv_padded: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_kv: Optional[int] = None, + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + qkv_format: str = None, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_kv: torch.Tensor = None, + cu_seqlens_q_padded: torch.Tensor = None, + cu_seqlens_kv_padded: torch.Tensor = None, + max_seqlen_q: int = None, + max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, checkpoint_core_attention: bool = False, @@ -7629,9 +7629,22 @@ def forward( value_layer.shape[-1] == self.hidden_size_per_attention_head_v ), f"Values have head_dim = {value_layer.shape[-1]}, " "but 
expected head_dim = {self.hidden_size_per_attention_head_v}!" + assert ( + key_layer.shape[-2] == self.num_gqa_groups_per_partition + and value_layer.shape[-2] == self.num_gqa_groups_per_partition + ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads! Found {key_layer.shape[-2]} in" + f" key_layer and {value_layer.shape[-2]} in value_layer." + ) if qkv_format is None: qkv_format = self.qkv_format + assert qkv_format in [ + "sbhd", + "bshd", + "thd", + ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" if attn_mask_type is None: attn_mask_type = self.attn_mask_type @@ -7655,19 +7668,13 @@ def forward( graph_safe_rng_available() ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." - if qkv_format is None: - qkv_format = self.qkv_format - - assert qkv_format in [ - "sbhd", - "bshd", - "thd", - ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" - if qkv_format == "thd": assert all( len(x.shape) == 3 for x in (query_layer, key_layer, value_layer) ), "Queries, keys and values must be 3D tensors when qkv_format = thd!" + assert ( + "padding" in attn_mask_type + ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" assert ( cu_seqlens_q is not None and cu_seqlens_kv is not None ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" @@ -7745,19 +7752,6 @@ def forward( # query tensor is now in inference_params.qkv_format qkv_format = target_qkv_format - if qkv_format == "thd": - assert ( - "padding" in attn_mask_type - ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" - assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), ( - "Keys and values must have num_gqa_group =" - f" {self.num_gqa_groups_per_partition} heads! 
Found {key_layer.shape[-2]} in" - f" key_layer and {value_layer.shape[-2]} in value_layer." - ) - cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 83b316aad4..be889f58e8 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -69,7 +69,6 @@ def _make_graphed_callables( """ Helper method for `make_graphed_callables` """ - if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): raise RuntimeError( "make_graphed_callables does not support the autocast " @@ -255,13 +254,16 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: hook.remove() - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), - only_inputs=True, - allow_unused=allow_unused_input, - ) + if callables[0].training: + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if isinstance(i, torch.Tensor) and i.requires_grad), + grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), + only_inputs=True, + allow_unused=allow_unused_input, + ) + else: + grad_inputs = None del outputs, grad_inputs # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from altering the control flow. 
@@ -366,22 +368,23 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) - with torch.cuda.graph(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input, - retain_graph=retain_graph_in_backward, - ) + if callables[0].training: + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if isinstance(i, torch.Tensor) and i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, + ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that # don't require grad. I couldn't think of a slick one-liner for this pattern. 
static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if arg.requires_grad: + if isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: @@ -422,7 +425,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Copy values from new tensors into static tensors for i in range(len_user_args): - if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + if isinstance(static_input_surface[i], torch.Tensor) and static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) # Replay forward graph From b87e539d0aa2617aeea525c0b4273cfa3ef58b27 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 11 Feb 2025 09:22:08 -0800 Subject: [PATCH 068/239] [JAX] Flax module init with a given dtype (#1472) * flax module to init params with given dtype Signed-off-by: Phuong Nguyen * all tests passed Signed-off-by: Phuong Nguyen * remove unneccessary reshape for kernel Signed-off-by: Phuong Nguyen * remove casting output of dot Signed-off-by: Phuong Nguyen * clean up Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/attention.py | 9 +- transformer_engine/jax/dot.py | 2 +- transformer_engine/jax/flax/module.py | 88 +++++++++++-------- transformer_engine/jax/flax/transformer.py | 16 +++- 4 files changed, 70 insertions(+), 45 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index ae3cfddccc..5ec556ab34 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -252,8 +252,13 @@ def abstract( k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype 
+ assert ( + q_dtype == k_dtype == v_dtype == bias_dtype + ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}" + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, ( + f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval}," + f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}" + ) batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) diff --git a/transformer_engine/jax/dot.py b/transformer_engine/jax/dot.py index cb8722e089..826b94a983 100644 --- a/transformer_engine/jax/dot.py +++ b/transformer_engine/jax/dot.py @@ -25,7 +25,7 @@ def type_safe_dot_general( """ if fp8_meta_pkg is None: - kernel = jnp.asarray(kernel, x.dtype) + assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}" return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ()))) amax_list = fp8_meta_pkg.amax_list diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 7aa14fb1ba..4c46eafb4c 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -59,17 +59,13 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga def _create_layernorm_parameters( layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype ): - scale = nn_partitioning.param_with_axes( - "scale", scale_init, shape, jnp.float32, axes=scale_axes - ) - scale = jnp.asarray(scale, dtype) + scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes) + scale = scale.astype(dtype) layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == "layernorm": - bias = nn_partitioning.param_with_axes( - "ln_bias", bias_init, shape, jnp.float32, axes=bias_axes - ) - bias = jnp.asarray(bias, dtype) + bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, 
dtype, axes=bias_axes) + bias = bias.astype(dtype) else: assert layernorm_type == "rmsnorm" bias = None @@ -280,7 +276,8 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma + self.scale_init, + self.zero_centered_gamma, ) super().__post_init__() @@ -299,6 +296,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: outputs : jax.numpy.ndarray Output tensors. """ + x = x.astype(self.dtype) features = x.shape[-1] scale, ln_bias = _create_layernorm_parameters( @@ -424,7 +422,9 @@ class DenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -452,14 +452,13 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - - kernel = jnp.reshape(kernel, kernel_shape) + kernel = kernel.astype(self.dtype) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) bias = bias.astype(self.dtype) else: @@ -490,7 +489,7 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - jnp.float32, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) @@ -502,7 +501,7 @@ def __call__(self, inputs: 
Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=lora_b_kernel_axes, ) lora_b_kernel = lora_b_kernel.astype(self.dtype) @@ -633,9 +632,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma + self.scale_init, + self.zero_centered_gamma, ) super().__post_init__() @@ -665,6 +667,7 @@ def __call__(self, inputs: Array) -> Array: and not self.return_layernorm_output and self.enable_layernorm ) + inputs = inputs.astype(self.dtype) if self.enable_layernorm: inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) @@ -709,10 +712,9 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - - kernel = jnp.reshape(kernel, kernel_shape) + kernel = kernel.astype(self.dtype) contract_ind = tuple(range(0, len(axis))) @@ -755,7 +757,7 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - jnp.float32, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) @@ -767,7 +769,7 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=lora_b_kernel_axes, ) lora_b_kernel = lora_b_kernel.astype(self.dtype) @@ -779,7 +781,7 
@@ def __call__(self, inputs: Array) -> Array: bias = None if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) bias = bias.astype(self.dtype) @@ -935,9 +937,12 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma + self.scale_init, + self.zero_centered_gamma, ) super().__post_init__() @@ -970,6 +975,8 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: and self.enable_layernorm ) + inputs = inputs.astype(self.dtype) + gated_act_pool = [ ("gelu", "linear"), ("silu", "linear"), @@ -1033,7 +1040,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): for _ in range(num_kernels): key, init_key = jax_random.split(key) kernels.append(self.kernel_init(init_key, *init_args)) - return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32) + return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype) wi_fp8_meta_pkg = None wo_fp8_meta_pkg = None @@ -1054,10 +1061,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, kernel_1_each_shape, - jnp.float32, + self.dtype, axes=self.kernel_axes_1, ) kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) + kernel_1 = kernel_1.astype(self.dtype) hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1066,10 +1074,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_kernel", self.kernel_init, kernel_2_param_shape, - jnp.float32, + self.dtype, axes=self.kernel_axes_2, ) 
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) + kernel_2 = kernel_2.astype(self.dtype) contract_ind = tuple(range(0, len(axis))) ffn1_ckpt_name = "ffn1" @@ -1081,13 +1090,13 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: bias_1_shape = intermediate_dim bias_1 = nn_partitioning.param_with_axes( - "wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1 + "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1 ) bias_1 = bias_1.astype(self.dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( - "wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2 + "wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2 ) bias_2 = bias_2.astype(self.dtype) else: @@ -1156,7 +1165,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, wi_lora_a_kernel_init_each_shape, - jnp.float32, + self.dtype, axes=wi_lora_a_kernel_axes, ) wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) @@ -1172,7 +1181,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_lora_b_kernel", nn.initializers.zeros, wi_lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=wi_lora_b_kernel_axes, ) wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) @@ -1189,10 +1198,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_1 = None if self.use_bias: bias_1 = nn_partitioning.param_with_axes( - "wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1 + "wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1 ) - bias_1 = bias_1.astype(self.dtype) bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape + bias_1 = bias_1.astype(self.dtype) x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) @@ -1207,6 +1216,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): z = functools.reduce(operator.mul, activations) # 
Remove act axis z = jnp.reshape(z, (*z.shape[:-2], -1)) + z = z.astype(self.dtype) + # import pdb; pdb.set_trace() z = nn.Dropout( rate=self.intermediate_dropout_rate, @@ -1215,6 +1226,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): )(z, deterministic=deterministic) z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes) + z = z.astype(self.dtype) # DenseGeneral 2 out = type_safe_dot_general( @@ -1228,7 +1240,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_a_kernel", self.kernel_init, wo_lora_a_kernel_shape, - jnp.float32, + self.dtype, axes=wo_lora_a_kernel_axes, ) wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) @@ -1239,7 +1251,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_b_kernel", nn.initializers.zeros, wo_lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=wo_lora_b_kernel_axes, ) wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) @@ -1256,7 +1268,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_2 = None if self.use_bias: bias_2 = nn_partitioning.param_with_axes( - "wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2 + "wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2 ) bias_2 = bias_2.astype(self.dtype) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index cf2b13d074..89278f720b 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -976,7 +976,9 @@ def __post_init__(self): ) if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -1198,6 +1200,7 @@ def 
generate_batch_seqlen_logical_axes(is_sharded_seq): inputs_kv = ln_out key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) + key = key.astype(self.dtype) value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) query = checkpoint_name(query, "query_proj") key = checkpoint_name(key, "key_proj") @@ -1437,7 +1440,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True): "rel_embedding", self.embedding_init, (self.num_attention_heads, self.num_buckets), - jnp.float32, + self.dtype, axes=self.embedding_axes, ) @@ -1673,10 +1676,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): if self.mha_kernel_init is None: - self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.mha_kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal" + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1726,6 +1731,9 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. """ + + inputs = inputs.astype(self.dtype) + assert ( self.layer_type in TransformerLayerType ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." 
From 09448a9d154b8f05352333b44dae10f19b855ff8 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Tue, 11 Feb 2025 16:33:14 -0800 Subject: [PATCH 069/239] WIP: non-paged Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 179 +++++---- .../common/util/pybind_helper.h | 4 + transformer_engine/pytorch/attention.py | 350 ++++++++++-------- .../pytorch/cpp_extensions/fused_attn.py | 7 + transformer_engine/pytorch/csrc/extensions.h | 10 + .../pytorch/csrc/extensions/attention.cu | 108 ++++++ .../pytorch/csrc/extensions/pybind.cpp | 1 + transformer_engine/pytorch/graph.py | 7 +- .../pytorch/kv_cache_manager_non_paged.py | 202 +++++++--- transformer_engine/pytorch/utils.py | 16 + 10 files changed, 616 insertions(+), 268 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 8250e53953..9b696d9535 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
from collections import OrderedDict +from typing import List import os import logging @@ -10,6 +11,7 @@ import torch from torch.distributions import Exponential +from transformer_engine.pytorch import make_graphed_callables from transformer_engine.pytorch.attention import ( DotProductAttention, InferenceParams, @@ -36,7 +38,7 @@ model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "infer_0": ModelConfig(4, 16, 16, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=8), + "infer_0": ModelConfig(4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8), #"infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), } @@ -69,9 +71,10 @@ def __init__( self.seq_ids = torch.range(0, total_requests-1, dtype=torch.int32, device="cpu") # simulate context lengths in Uniform distribution - self.context_lens = torch.randint( - 1, self.max_context_len, [total_requests], dtype=torch.int32, device="cpu" - ) + #self.context_lens = torch.randint( + # 1, self.max_context_len, [total_requests], dtype=torch.int32, device="cpu" + #) + self.context_lens = 10 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -79,16 +82,18 @@ def __init__( gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to( dtype=torch.int32, device="cpu" ) - self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( - dtype=torch.int32, device="cpu" - ) + #self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( + # dtype=torch.int32, device="cpu" + #) + self.gen_lens = 5 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate arrival times in Poisson distribution if poisson_rate is None: self.poisson_rate = torch.randint(1, max_batch_size, [1]).item() interval_dist = Exponential(self.poisson_rate) arrival_intervals = interval_dist.sample((total_requests,)) - self.arrival_times = 
torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") + #self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") + self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() # initialize tensors @@ -315,63 +320,78 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): num_heads_q=config.num_heads, head_dim_q=config.head_dim_qk, ) - inference_params.allocate_memory(layer_number) + inference_params.allocate_memory(layer_number, qkv_format) inference_params.print() -# def generate_data( -# model_config: ModelConfig, -# dtype: torch.dtype, -# warmup: bool = False, -# ) -> List[torch.Tensor]: -# """Generate synthetic data for dot product attention.""" -# gen_func = torch.ones if warmup else torch.randn -# aa=[ -# gen_func( -# model_config.batch_size, -# model_config.max_seqlen_q, -# model_config.num_heads, -# model_config.head_dim_qk, -# device="cuda", -# #requires_grad=True, -# dtype=dtype, -# ) -# for _ in range(3) -# ] -# #aa.extend([model_config.sequence_length, model_config.sequence_length]) -# return aa -# -# def gen_cu( -# model_config: ModelConfig, -# dtype: torch.dtype, -# ): -# cu_dict = {} -# cu_dict["cu_seqlens_q"] = torch.linspace( 0, -# model_config.batch_size * model_config.max_seqlen_q, -# steps=model_config.batch_size+1, -# device="cuda", -# dtype=torch.int32, -# ) -# cu_dict["cu_seqlens_kv"] = torch.linspace( 0, -# model_config.batch_size * model_config.max_seqlen_kv, -# steps=model_config.batch_size+1, -# device="cuda", -# dtype=torch.int32, -# ) -# cu_dict["max_seqlen_q"] = model_config.max_seqlen_q -# cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv -# return cu_dict -# + def generate_data( + model_config: ModelConfig, + dtype: torch.dtype, + warmup: bool = False, + ) -> List[torch.Tensor]: + """Generate synthetic data for dot product attention.""" + gen_func = torch.ones if warmup 
else torch.randn + aa=[ + gen_func( + model_config.batch_size, + 64, #model_config.max_seqlen_q, + model_config.num_heads, + model_config.head_dim_qk, + device="cuda", + #requires_grad=True, + dtype=dtype, + ) + for _ in range(3) + ] + #aa.extend([model_config.sequence_length, model_config.sequence_length]) + return aa + + def gen_cu( + model_config: ModelConfig, + dtype: torch.dtype, + ): + cu_dict = {} + cu_dict["cu_seqlens_q"] = torch.linspace( 0, + model_config.batch_size * 1, #model_config.max_seqlen_q, + #model_config.batch_size * model_config.max_seqlen_q, + steps=model_config.batch_size+1, + device="cuda", + dtype=torch.int32, + ) + cu_dict["cu_seqlens_kv"] = torch.linspace( 0, + model_config.batch_size * 1, #model_config.max_seqlen_kv, + #model_config.batch_size * model_config.max_seqlen_kv, + steps=model_config.batch_size+1, + device="cuda", + dtype=torch.int32, + ) + cu_dict["max_seqlen_q"] = model_config.max_seqlen_q + cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv + cu_dict["inference_params"] = inference_params + cu_dict["attn_mask_type"] = attn_mask_type + #cu_dict["max_seqlen_q"] = max_seqlen_q_infer + #cu_dict["max_seqlen_kv"] = max_seqlen_kv_roundup + cu_dict["qkv_format"] = qkv_format + return cu_dict + +# t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") +# step_lens = torch.ones(max_batch_size, dtype=torch.int32, device="cpu") +# step_dict = OrderedDict( +# zip(t_seq_ids.tolist(), step_lens.tolist()) +# ) +# inference_params.prepare(step_dict) # model = make_graphed_callables( # model, -# generate_data_for_dot_product_attention(model_config, dtype, warmup=True), +# generate_data(config, dtype, warmup=True), # num_warmup_iters=10, # fp8_enabled=False, # #sample_kwargs={"qkv_format":"thd"}, -# sample_kwargs=gen_cu(model_config, dtype), +# sample_kwargs=gen_cu(config, dtype), # ) - +# print('AAAAAAAAAAAAfter graphed') # similate step by step sim.reset() + graphed = False + model_orig = model while True: if 
inference_params.is_paged: inference_params.cache_manager.print_cache() @@ -460,18 +480,50 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): cu_seqlens_kv = torch.zeros(sim.t_batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) - inference_params.step_dict = OrderedDict( + step_dict = OrderedDict( zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) ) - + inference_params.prepare(step_dict) + + #if sim.step_lens[0] == 1 and not graphed: + # model_graphed = make_graphed_callables( + # model, + # generate_data(config, dtype, warmup=True), + # num_warmup_iters=10, + # fp8_enabled=False, + # #sample_kwargs={"qkv_format":"thd"}, + # sample_kwargs=gen_cu(config, dtype), + # ) + # graphed = True + # print('AAAAAAAAAAAAfter graphed') + if not graphed: + model = make_graphed_callables( + model, + generate_data(config, dtype, warmup=True), + num_warmup_iters=10, + fp8_enabled=False, + #sample_kwargs={"qkv_format":"thd"}, + sample_kwargs=gen_cu(config, dtype), + ) + graphed = True + print('AAAAAAAAAAAAfter graphed') + print('incremental shapes', [x.shape for x in [ incremental_q, incremental_k, incremental_v]]) + + #if sim.step_lens[0] == 1 and graphed: + # model = model_graphed + #else: + # model = model_orig line_output = model( - query_layer=incremental_q, - key_layer=incremental_k, - value_layer=incremental_v, - inference_params=inference_params, - attn_mask_type=attn_mask_type, + #query_layer=incremental_q, + #key_layer=incremental_k, + #value_layer=incremental_v, + incremental_q, + incremental_k, + incremental_v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + inference_params=inference_params, + attn_mask_type=attn_mask_type, max_seqlen_q=max_seqlen_q_infer, max_seqlen_kv=max_seqlen_kv_roundup, qkv_format=qkv_format, @@ -491,6 +543,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): } for i, seq in 
enumerate(sim.t_seq_ids): if qkv_format == "bshd": + print(i,seq, sim.t_total_lens[i], sim.step_lens[i]) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], line_output[i, sim.step_lens[i] - 1, :], diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index aa181364f2..768fd2797a 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -36,6 +36,10 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 67592f7550..54d14497fb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -56,6 +56,7 @@ split_tensor_along_dim, get_device_compute_capability, get_default_init_method, + StaticBufferAllocator, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -1025,7 +1026,6 @@ def get_attention_backend( available_backends, ) - class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order @@ -1126,21 +1126,23 @@ def __init__( # memory format for the cache; at the moment, only 'bshd' is supported self.qkv_format = "bshd" # layer numbers that we have kv cache for - 
self.layer_numbers = [] + #self.layer_numbers = [] # sequence ids that are stored in the cache - self.seq_ids = [] + #self.seq_ids = [] # the full sequence lengths for sequences in seq_ids - self.seqlens = [0] * self.max_batch_size + #self.seq_lens = [] #0] * self.max_batch_size + self.sequences = collections.OrderedDict() #zip(self.seq_ids, self.seq_lens)) # the {seq_id: step_len} information for a new inference step # e.g. inference_params.step_dict = {2: 1, 3: 1, 4: 10}, if in this iteration, # we have three sequences in the batch: sequences 2 and 3 are in generation phase # with step_len = 1 and sequence 4 is in context phase with 10 new tokens + #self.step_lens = [] self.step_dict = collections.OrderedDict() # the query buffer when is_cuda_graph = True - if self.is_cuda_graph: - self.q_buffer = {} - self.cu_seqlens_q_buffer = [] - self.cu_seqlens_kv_buffer = [] + #if self.is_cuda_graph: + # self.q_buffer = {} + # self.cu_seqlens_q_buffer = [] + # self.cu_seqlens_kv_buffer = [] def print(self): """Print InferenceParams parameters""" @@ -1156,9 +1158,9 @@ def print(self): logger.debug(" page_size: %s", self.page_size) logger.debug(" num_heads_kv: %s", self.num_heads_kv) logger.debug(" head_dim: k: %s, v: %s", self.head_dim_k, self.head_dim_v) - logger.debug(" layer_numbers: %s", self.layer_numbers) + #logger.debug(" layer_numbers: %s", self.layer_numbers) - def allocate_memory(self, layer_number): + def allocate_memory(self, layer_number: int, qkv_format: str): """ Allocate memory for the KV cache for the layer #layer_number. Both K cache and V cache are in 'bshd' format. 
@@ -1173,10 +1175,9 @@ def allocate_memory(self, layer_number): - cu_seqlens_q buffer: [max_batch_size + 1] - cu_seqlens_kv buffer: [max_batch_size + 1] """ - self.layer_numbers.append(layer_number) - self.cache_manager.allocate_memory(layer_number) + #self.layer_numbers.append(layer_number) - if self.is_cuda_graph: + if qkv_format == 'thd': #self.is_cuda_graph: self.max_seqlen_q = self.max_seqlen_kv self.q_buffer[layer_number] = torch.zeros( self.max_batch_size, @@ -1186,82 +1187,105 @@ def allocate_memory(self, layer_number): dtype=self.dtype, device=torch.cuda.current_device(), ) - self.cu_seqlens_q_buffer = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - self.cu_seqlens_kv_buffer = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) + self.cache_manager.allocate_memory(layer_number) + self.cu_seqlens_q = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_kv = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) - def reshape_and_copy_q( + def prepare( self, - q: torch.Tensor, - source_qkv_format: str, - target_qkv_format: str, # pylint: disable=unused-argument - layer_number: Optional[int] = None, + step_dict: Dict[List, List], ): - """ - Convert the new query tokens from 'source_qkv_format' to 'target_qkv_format', - so that it is consistent with the KV cache format. At the moment, only 'bshd' format - is supported for target_qkv_format. If is_cuda_graph = True, also copy the new query - tensor to the appropriate q_buffer. 
- """ + self.sequences = self.cache_manager.prepare(self.sequences, step_dict) + self.step_dict = step_dict + actual_batch_size = len(self.step_dict) seqlens_q = list(self.step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] - batch_wide_max_seqlen_q = int((max(seqlens_q) + 63) // 64 * 64) - if not self.is_cuda_graph: - if source_qkv_format == "bshd": - q = q.contiguous() - if source_qkv_format == "sbhd": - q = q.transpose(0, 1).contiguous() - if source_qkv_format == "thd": - padded_q = torch.zeros( - actual_batch_size, - batch_wide_max_seqlen_q, - q.shape[-2], - q.shape[-1], - dtype=q.dtype, - device="cuda", - ) - for i in range(actual_batch_size): - padded_q[i, : seqlens_q[i], :, :] = q[ - cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : - ] - q = padded_q - - if source_qkv_format in ["bshd", "sbhd"]: - self.max_seqlen_q = q.shape[1] - else: - self.max_seqlen_q = batch_wide_max_seqlen_q - - # bshd: [actual_batch_size, batch_wide_max_seqlen_q, num_heads_q, head_dim_q] - return q - - assert ( - layer_number is not None and layer_number in self.layer_numbers - ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" 
- q_buffer = self.q_buffer[layer_number] - for i in range(actual_batch_size): - if source_qkv_format == "bshd": - q_buffer[i, : seqlens_q[i], :, :] = q[i, : seqlens_q[i], :, :] - if source_qkv_format == "sbhd": - q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] - if source_qkv_format == "thd": - q_buffer[i, : seqlens_q[i], :, :] = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, :] - q_buffer[i, seqlens_q[i] :, :, :].fill_(0) - - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) - self.cu_seqlens_q_buffer.copy_( + self.seq_lens = list(self.sequences.values()) + self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") ) + cu_seqlens_kv = [0] + [sum(self.seq_lens[:i]) for i in range(1, actual_batch_size + 1)] + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + self.max_batch_size - actual_batch_size + ) + self.cu_seqlens_kv.copy_( + torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") + ) - # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] - return q_buffer +# def reshape_and_copy_q( +# self, +# q: torch.Tensor, +# source_qkv_format: str, +# target_qkv_format: str, # pylint: disable=unused-argument +# layer_number: Optional[int] = None, +# ): +# """ +# Convert the new query tokens from 'source_qkv_format' to 'target_qkv_format', +# so that it is consistent with the KV cache format. At the moment, only 'bshd' format +# is supported for target_qkv_format. If is_cuda_graph = True, also copy the new query +# tensor to the appropriate q_buffer. 
+# """ +# actual_batch_size = len(self.step_dict) +# seqlens_q = list(self.step_dict.values()) +# cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] +# batch_wide_max_seqlen_q = int((max(seqlens_q) + 63) // 64 * 64) +# if not self.is_cuda_graph: +# if source_qkv_format == "bshd": +# q = q.contiguous() +# if source_qkv_format == "sbhd": +# q = q.transpose(0, 1).contiguous() +# if source_qkv_format == "thd": +# padded_q = torch.zeros( +# actual_batch_size, +# batch_wide_max_seqlen_q, +# q.shape[-2], +# q.shape[-1], +# dtype=q.dtype, +# device="cuda", +# ) +# for i in range(actual_batch_size): +# padded_q[i, : seqlens_q[i], :, :] = q[ +# cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : +# ] +# q = padded_q +# +# if source_qkv_format in ["bshd", "sbhd"]: +# self.max_seqlen_q = q.shape[1] +# else: +# self.max_seqlen_q = batch_wide_max_seqlen_q +# +# # bshd: [actual_batch_size, batch_wide_max_seqlen_q, num_heads_q, head_dim_q] +# return q +# +# assert ( +# layer_number is not None and layer_number in self.layer_numbers +# ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" 
+# q_buffer = self.q_buffer[layer_number] +# for i in range(actual_batch_size): +# if source_qkv_format == "bshd": +# q_buffer[i, : seqlens_q[i], :, :] = q[i, : seqlens_q[i], :, :] +# if source_qkv_format == "sbhd": +# q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] +# if source_qkv_format == "thd": +# q_buffer[i, : seqlens_q[i], :, :] = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, :] +# q_buffer[i, seqlens_q[i] :, :, :].fill_(0) +# +# cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) +# self.cu_seqlens_q_buffer.copy_( +# torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") +# ) +# +# # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] +# return q_buffer def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): """ @@ -1328,6 +1352,7 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): def update_cache( self, layer_number: int, + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_format: str, @@ -1386,25 +1411,54 @@ def update_cache( page_table: torch.Tensor The page table if is_paged = True; else `None` """ + actual_batch_size = len(self.step_dict) + seqlens_q = list(self.step_dict.values()) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + if qkv_format == "bshd": + q = q.contiguous() + if qkv_format == "sbhd": + q = q.transpose(0, 1).contiguous() + if qkv_format == "thd": + q_buffer = self.q_buffer[layer_number] + for i in range(actual_batch_size): + q_buffer[i, : seqlens_q[i], :, :] = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, :] + q = q_buffer + + #self.page_table = page_table + #self.seq_ids = list(self.cache_manager.sequences.keys()) + #self.seqlens = list(self.cache_manager.sequences.values()) + self.seq_lens = list(self.sequences.values()) + #print('self.sequences',self.sequences) + #print(self.max_batch_size, actual_batch_size) + + #self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( + # 
torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") + #) + #cu_seqlens_kv = [0] + [sum(self.seq_lens[:i]) for i in range(1, actual_batch_size + 1)] + #cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + # self.max_batch_size - actual_batch_size + #) + #self.cu_seqlens_kv.copy_( + # torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") + #) k_cache, v_cache, page_table = self.cache_manager.step( - layer_number, k, v, self.step_dict, qkv_format + #layer_number, k, v, self.step_dict, qkv_format, self.cu_seqlens_q, self.cu_seqlens_kv, + layer_number, k, v, self.cu_seqlens_q, self.cu_seqlens_kv, qkv_format, ) - self.page_table = page_table - self.seq_ids = list(self.cache_manager.sequences.keys()) - self.seqlens = list(self.cache_manager.sequences.values()) - if self.is_cuda_graph: - actual_batch_size = len(self.seqlens) - cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size + 1)] - cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( - self.max_batch_size - actual_batch_size - ) - self.cu_seqlens_kv_buffer.copy_( - torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") - ) + #if self.is_cuda_graph: + # actual_batch_size = len(self.seqlens) + # cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size + 1)] + # cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + # self.max_batch_size - actual_batch_size + # ) + # self.cu_seqlens_kv_buffer.copy_( + # torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") + # ) # k_cache and v_cache are in InferenceParams.qkv_format format - return k_cache, v_cache, page_table +# return k_cache, v_cache, page_table + return q, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv @torch.no_grad() @@ -7700,58 +7754,6 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - page_table = None - if inference_params is not None: - assert 
self.layer_number is not None, "Layer number must be set!" - - # remember original format for output purposes - orig_qkv_format = qkv_format - - # convert causal to causal_bottom_right in inference when KV-caching is in use - # so users can run with the same attn_mask_type for training and inference - if attn_mask_type in ["causal", "padding_causal"]: - attn_mask_type = attn_mask_type + "_bottom_right" - - # convert to cross attention type when KV cache is in use - self.attention_type = "cross" - self.flash_attention.attention_type = self.attention_type - self.fused_attention.attention_type = self.attention_type - self.unfused_attention.attention_type = self.attention_type - - # force tensors to be contiguous if not already - query_layer, key_layer, value_layer = [ - x.contiguous() if not x.is_contiguous() else x - for x in [query_layer, key_layer, value_layer] - ] - - # reshape the query tensor - # cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but - # flash-attention and unfused attention will need q/k/v in the - # same qkv_format - target_qkv_format = inference_params.qkv_format - query_layer = inference_params.reshape_and_copy_q( - query_layer, qkv_format, target_qkv_format, self.layer_number - ) - - # update KV cache and return the full key/value tensors - # full key/value tensors are in inference_params.qkv_format format - key_layer, value_layer, page_table = inference_params.update_cache( - self.layer_number, - key_layer, - value_layer, - qkv_format, - ) - - # update cu_seqlens tensors - if inference_params.is_cuda_graph: - cu_seqlens_q = inference_params.cu_seqlens_q_buffer - cu_seqlens_kv = inference_params.cu_seqlens_kv_buffer - max_seqlen_q = inference_params.max_seqlen_q - max_seqlen_kv = inference_params.max_seqlen_kv - - # query tensor is now in inference_params.qkv_format - qkv_format = target_qkv_format - cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -7797,6 +7799,62 @@ 
def forward( key_layer.device, ) + page_table = None + if inference_params is not None: + assert self.layer_number is not None, "Layer number must be set!" + + # remember original format for output purposes + orig_qkv_format = qkv_format + + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + + # convert to cross attention type when KV cache is in use + self.attention_type = "cross" + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type + + # force tensors to be contiguous if not already + query_layer, key_layer, value_layer = [ + x.contiguous() if not x.is_contiguous() else x + for x in [query_layer, key_layer, value_layer] + ] + + # reshape the query tensor + # cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but + # flash-attention and unfused attention will need q/k/v in the + # same qkv_format + #target_qkv_format = inference_params.qkv_format + #query_layer = inference_params.reshape_and_copy_q( + # query_layer, qkv_format, target_qkv_format, self.layer_number + #) + + # update KV cache and return the full key/value tensors + # full key/value tensors are in inference_params.qkv_format format + query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv = inference_params.update_cache( + self.layer_number, + query_layer, + key_layer, + value_layer, + qkv_format, + ) + #print('cu_seqlens_q',cu_seqlens_q) + #print('cu_seqlens_kv',cu_seqlens_kv) + + # update cu_seqlens tensors + #if inference_params.is_cuda_graph: + # cu_seqlens_q = inference_params.cu_seqlens_q_buffer + # cu_seqlens_kv = inference_params.cu_seqlens_kv_buffer + # max_seqlen_q = inference_params.max_seqlen_q + # max_seqlen_kv = 
inference_params.max_seqlen_kv + + # query tensor is now in inference_params.qkv_format + #qkv_format = target_qkv_format + qkv_format = inference_params.qkv_format + if ( isinstance(query_layer, Float8Tensor) and isinstance(key_layer, Float8Tensor) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 157ac08084..6b194f963a 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -9,6 +9,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import ( NVTE_QKV_Layout, + NVTE_QKV_Format, NVTE_Bias_Type, NVTE_Mask_Type, NVTE_Fused_Attn_Backend, @@ -31,6 +32,12 @@ tex.DType.kInt32: torch.int32, } +QKVFormat = { + "bshd": NVTE_QKV_Format.NVTE_BSHD, + "sbhd": NVTE_QKV_Format.NVTE_SBHD, + "thd": NVTE_QKV_Format.NVTE_THD, +} + QKVLayout = { "sb3hd": NVTE_QKV_Layout.NVTE_SB3HD, "sbh3d": NVTE_QKV_Layout.NVTE_SBH3D, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d29ab52e3f..85497fc1fb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -34,6 +34,16 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T /*************************************************************************************************** * Attention **************************************************************************************************/ +void copy_to_kv_cache_non_paged( + torch::Tensor new_k, torch::Tensor new_v, + torch::Tensor k_cache, torch::Tensor v_cache, + torch::Tensor batch_indices, + torch::Tensor step_lens, + torch::Tensor seq_lens, + NVTE_QKV_Format qkv_format, + int h, int d, + int b, int max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens); NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, diff 
--git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9c4285964c..513dc78d26 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -13,6 +13,114 @@ using namespace transformer_engine::fused_attn; constexpr int block_size = 512; constexpr int ctas_per_sm = 4; +template +__global__ void copy_to_kv_cache_non_paged_kernel( + scalar_t* new_k, scalar_t* new_v, + scalar_t* k_cache, scalar_t* v_cache, + int* batch_indices, + int* step_lens, + int* seq_lens, + NVTE_QKV_Format qkv_format, + int h, int d, + int b, int max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // batch_indices, step_lens, seq_lens: [b + 1] + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h * d; + int new_token_offset = batch_idx * max_ctx_len * h * d; + int cache_offset = batch_idx * max_seq_len * h * d + (seq_lens[batch_idx] - step_lens[batch_idx]) * h * d; + + scalar_t* new_k_token = new_k + new_token_offset; + scalar_t* k_cache_token = k_cache + cache_offset; + scalar_t* new_v_token = new_v + new_token_offset; + scalar_t* v_cache_token = v_cache + cache_offset; + + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(k_cache_token + i) = *(new_k_token + i); + *(v_cache_token + i) = *(new_v_token + i); + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - step_lens[batch_idx]); + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + for (int j = 0; j < h * d; j ++) { + *(k_cache + (cache_offset + i) * h * d + j) = *(new_k + (i * b + batch_idx) * h * d +j); + *(v_cache + 
(cache_offset + i) * h * d + j) = *(new_v + (i * b + batch_idx) * h * d +j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + // no padding between sequences in new_k and new_v + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h * d; + int new_token_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + new_token_offset += step_lens[t]; + } + new_token_offset = new_token_offset * h * d; + int cache_offset = batch_idx * max_seq_len * h * d + (seq_lens[batch_idx] - step_lens[batch_idx]) * h * d; + + scalar_t* new_k_token = new_k + new_token_offset; + scalar_t* k_cache_token = k_cache + cache_offset; + scalar_t* new_v_token = new_v + new_token_offset; + scalar_t* v_cache_token = v_cache + cache_offset; + + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(k_cache_token + i) = *(new_k_token + i); + *(v_cache_token + i) = *(new_v_token + i); + } + } + } +} +template +void copy_to_kv_cache_non_paged_launcher( + torch::Tensor new_k, torch::Tensor new_v, + torch::Tensor k_cache, torch::Tensor v_cache, + torch::Tensor batch_indices, + torch::Tensor step_lens, + torch::Tensor seq_lens, + NVTE_QKV_Format qkv_format, + int h, int d, + int b, int max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens) { + copy_to_kv_cache_non_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + batch_indices.data_ptr(), + step_lens.data_ptr(), + seq_lens.data_ptr(), + qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); +} + +void copy_to_kv_cache_non_paged( + torch::Tensor new_k, torch::Tensor new_v, + torch::Tensor k_cache, torch::Tensor v_cache, + torch::Tensor batch_indices, + torch::Tensor step_lens, + torch::Tensor seq_lens, + NVTE_QKV_Format qkv_format, + int h, int d, + int b, int 
max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens) { + if (k_cache.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + + } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + } else if (k_cache.scalar_type() == at::ScalarType::Float) { + using dtype = float; + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + } else { + NVTE_ERROR("Unsupported dtype.\n"); + } +} // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 442837d767..33ecd3a7cd 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -171,6 +171,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); + m.def("copy_to_kv_cache_non_paged", ©_to_kv_cache_non_paged, "Copy KV to non-paged KV cache"); m.def("fused_attn_fwd", &fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &fused_attn_bwd, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index be889f58e8..92e20b2340 100644 --- 
a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -246,7 +246,8 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] - for _ in range(num_warmup_iters): + for ii in range(num_warmup_iters): + print("------ warmup ", ii) hooks = [] for module in func.modules(): hook = module.register_forward_hook(hook_fn) @@ -265,6 +266,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument else: grad_inputs = None del outputs, grad_inputs + print("------ end warmup ------") # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from altering the control flow. for module in func.modules(): @@ -426,6 +428,9 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Copy values from new tensors into static tensors for i in range(len_user_args): if isinstance(static_input_surface[i], torch.Tensor) and static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + print(i, inputs[i].shape, static_input_surface[i].shape) + if inputs[i].ndim == 1: + print('input', i, inputs[i]) static_input_surface[i].copy_(inputs[i]) # Replay forward graph diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 0f9ce5da66..220d04bb9d 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -4,8 +4,11 @@ """Non-Paged KV Cache Manager.""" from collections import OrderedDict -from typing import Optional +from typing import Optional, Dict, List import torch +#from transformer_engine.pytorch.utils import StaticBufferAllocator +import transformer_engine_torch as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat class NonPagedKVCacheManager: @@ -36,6 +39,13 @@ def __init__( 
self.sequences = OrderedDict() # KV cache tuple (k_cache, v_cache) self.cache = {} +# self._allocator = StaticBufferAllocator() +# +# def alloc(self, size, dtype, device): +# """ +# Allocated the buffer and works correctly with CUDA Graphs. +# """ +# return self._allocator(size, dtype, device) def allocate_memory(self, layer_number): """Allocate memory for the KV cache""" @@ -57,12 +67,58 @@ def allocate_memory(self, layer_number): ) self.cache[layer_number] = (k_cache, v_cache) + #self.batch_indices = self.alloc( + self.batch_indices = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def prepare( + self, + sequences: Dict[List, List], + step_dict: Dict[List, List], + ): + self.sequences = sequences + #self.step_dict = step_dict + prev_batch_size = len(self.sequences) + batch_size = len(step_dict) + + # Reorder cache + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] + finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] + self.batch_indices.copy_(torch.Tensor(( + unfinished_indices + + finished_indices + + list(range(prev_batch_size, self.max_batch_size)) + )).to(dtype=torch.int32, device="cpu")) + print('self.batch_indices', self.batch_indices) + + # Advance unfinished sequences + for i in unfinished_seqs: + self.sequences[i] += 1 + + # Remove finished sequences + for i in finished_seqs: + self.sequences.pop(i) + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for i in new_seqs: + self.sequences[i] = step_dict[i] + + return self.sequences + def step( self, layer_number, k: torch.Tensor, v: torch.Tensor, - step_dict: OrderedDict, + #step_dict: OrderedDict, + cu_seqlens_q, + cu_seqlens_kv, qkv_format: str, ): """ @@ -90,60 +146,88 @@ def step( The value cache tensor containing previous and the 
current tokens """ k_cache, v_cache = self.cache[layer_number] - prev_batch_size = len(self.sequences) - batch_size = len(step_dict) - - # Reorder cache - unfinished_seqs = self.sequences.keys() & step_dict.keys() - finished_seqs = self.sequences.keys() - unfinished_seqs - unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] - finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - batch_indices = ( - unfinished_indices - + finished_indices - + list(range(prev_batch_size, self.max_batch_size)) - ) - new_k_cache = k_cache[batch_indices, :] - new_v_cache = v_cache[batch_indices, :] - new_k_cache = new_k_cache.contiguous() - new_v_cache = new_v_cache.contiguous() - - # Advance unfinished sequences - for i in unfinished_seqs: - self.sequences[i] += 1 - - # Remove finished sequences - for i in finished_seqs: - self.sequences.pop(i) - - # Add new sequences - new_seqs = step_dict.keys() - self.sequences.keys() - for i in new_seqs: - self.sequences[i] = step_dict[i] - - # Copy new key/value tokens to cache - step_lens = list(step_dict.values()) - cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] - for i, seq in enumerate(self.sequences): - seq_s = self.sequences[seq] - step_dict[seq] - seq_e = self.sequences[seq] - if qkv_format == "bshd": - new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_dict[seq], :, :] - new_v_cache[i, seq_s:seq_e, :, :] = v[i, : step_dict[seq], :, :] - if qkv_format == "sbhd": - new_k_cache[i, seq_s:seq_e, :, :] = k[: step_dict[seq], i, :, :] - new_v_cache[i, seq_s:seq_e, :, :] = v[: step_dict[seq], i, :, :] - if qkv_format == "thd": - new_k_cache[i, seq_s:seq_e, :, :] = k[cu_seqlens[i] : cu_seqlens[i + 1], :, :] - new_v_cache[i, seq_s:seq_e, :, :] = v[cu_seqlens[i] : cu_seqlens[i + 1], :, :] - self.cache[layer_number] = (new_k_cache, new_v_cache) - - # Return full key/value tensors for attention calculation - if self.is_cuda_graph: - # [max_batch_size, 
max_seqlen_kv, num_heads_kv, head_dim_kv] - return new_k_cache, new_v_cache, None - - # [actual_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] - new_k_cache = new_k_cache[:batch_size].contiguous() - new_v_cache = new_v_cache[:batch_size].contiguous() - return new_k_cache, new_v_cache, None + step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + h=16 + d=64 + b=4 + max_ctx_len=k.shape[1] #64 + max_seq_len=k_cache.shape[1] #64 #128 + max_ctx_tokens=1 + max_tokens=1024 + print('kv shapes ', [x.shape for x in [k, v, k_cache, v_cache]]) + #print('step_lens ', step_lens) + #print('seq_lens ', seq_lens) + #print('self.batch_indices ', self.batch_indices) + tex.copy_to_kv_cache_non_paged(k, v, k_cache, v_cache, self.batch_indices, step_lens, seq_lens, QKVFormat[qkv_format], h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + return k_cache, v_cache, None + +# #prev_batch_size = len(self.sequences) +# #batch_size = len(step_dict) +# batch_size = len(self.sequences) +# +# ## Reorder cache +# #unfinished_seqs = self.sequences.keys() & step_dict.keys() +# #finished_seqs = self.sequences.keys() - unfinished_seqs +# #unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] +# #finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] +# #batch_indices = ( +# # unfinished_indices +# # + finished_indices +# # + list(range(prev_batch_size, self.max_batch_size)) +# #) +# new_k_cache = k_cache[self.batch_indices, :] +# new_v_cache = v_cache[self.batch_indices, :] +# new_k_cache = new_k_cache.contiguous() +# new_v_cache = new_v_cache.contiguous() +# +# ## Advance unfinished sequences +# #for i in unfinished_seqs: +# # self.sequences[i] += 1 +# +# ## Remove finished sequences +# #for i in finished_seqs: +# # self.sequences.pop(i) +# +# ## Add new sequences +# #new_seqs = step_dict.keys() - self.sequences.keys() +# #for i in new_seqs: +# # self.sequences[i] = 
step_dict[i] +# +# # Copy new key/value tokens to cache +# #step_lens = list(step_dict.values()) +# #cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] +# cu_seqlens = cu_seqlens_q +# step_lens = cu_seqlens[1:] - cu_seqlens[:-1] +# seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] +# #print('self.sequences', self.sequences) +# #print('cu_seqlens_q', cu_seqlens_q) +# #print('cu_seqlens_kv', cu_seqlens_kv) +# #print('step_lens', step_lens) +# for i, seq in enumerate(self.sequences.keys()): +# print('kv cm non-paged i', i, 'seq', seq) +# #seq_s = self.sequences[seq] - step_lens[i] +# #seq_e = self.sequences[seq] +# seq_s = seq_lens[i] - step_lens[i] +# seq_e = seq_lens[i] +# if qkv_format == "bshd": +# print('bshd ', [x.device for x in [new_k_cache, step_lens]]) +# new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_lens[i], :, :] +# new_v_cache[i, seq_s:seq_e, :, :] = v[i, : step_lens[i], :, :] +# if qkv_format == "sbhd": +# new_k_cache[i, seq_s:seq_e, :, :] = k[: step_lens[i], i, :, :] +# new_v_cache[i, seq_s:seq_e, :, :] = v[: step_lens[i], i, :, :] +# if qkv_format == "thd": +# new_k_cache[i, seq_s:seq_e, :, :] = k[cu_seqlens[i] : cu_seqlens[i + 1], :, :] +# new_v_cache[i, seq_s:seq_e, :, :] = v[cu_seqlens[i] : cu_seqlens[i + 1], :, :] +# self.cache[layer_number] = (new_k_cache, new_v_cache) +# +# # Return full key/value tensors for attention calculation +# if self.is_cuda_graph: +# # [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] +# return new_k_cache, new_v_cache, None +# +# # [actual_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] +# new_k_cache = new_k_cache[:batch_size].contiguous() +# new_v_cache = new_v_cache[:batch_size].contiguous() +# return new_k_cache, new_v_cache, None diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5b1bd82221..ee4df62020 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -326,3 +326,19 @@ def 
round_up_to_nearest_multiple(value, multiple): if multiple == 0: raise ValueError("multiple cannot be zero.") return ((value + multiple - 1) // multiple) * multiple + +class StaticBufferAllocator(torch.nn.Module): + """ + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthless, + torch API make_graphed_callable() takes care of output of torch modules, + and makes them static. Thus by wrapping allocation of memory into + torch.nn.Module, we can greatly simplify our code. + """ + + # pylint: disable=no-self-use + def forward(self, size, dtype, device): + """ + Return buffer of given size, dtype and device. + """ + return torch.zeros(size, dtype=dtype, device=device) From 49a4535d1addd2c5743a7e280e2f4f2640f0bedf Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Tue, 11 Feb 2025 18:39:18 -0800 Subject: [PATCH 070/239] Add NVTX ranges to categorize execution (#1447) Signed-off-by: Jaemin Choi Signed-off-by: Tim Moon Co-authored-by: Jaemin Choi Co-authored-by: Tim Moon --- transformer_engine/pytorch/attention.py | 18 +++++- .../pytorch/module/layernorm_linear.py | 44 +++++++++++++- transformer_engine/pytorch/module/linear.py | 34 ++++++++++- transformer_engine/pytorch/utils.py | 60 +++++++++++++++++++ 4 files changed, 150 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bf6adc309c..8584431dc2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -23,7 +23,11 @@ import transformer_engine_torch as tex import transformer_engine as te -from transformer_engine.pytorch.utils import get_cudnn_version +from transformer_engine.pytorch.utils import ( + get_cudnn_version, + nvtx_range_pop, + nvtx_range_push, +) from transformer_engine.pytorch.cpp_extensions.fused_attn import ( fused_attn_fwd, fused_attn_bwd, @@ -1834,6 +1838,7 @@ def forward( quantizers, ): # pylint: 
disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2756,12 +2761,14 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a @@ -3602,6 +3609,7 @@ def backward(ctx, dout): dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") return ( None, @@ -3688,6 +3696,7 @@ def forward( cp_stream, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -3904,11 +3913,13 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") return out @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) @@ -4092,6 +4103,7 @@ def backward(ctx, dout): dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) dk = dk.movedim(0, seq_dim).contiguous() dv = dv.movedim(0, seq_dim).contiguous() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( None, @@ -4151,6 +4163,7 @@ def forward( quantizers, ): # 
pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -4403,11 +4416,13 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) ( @@ -4592,6 +4607,7 @@ def backward(ctx, dout): dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) if not ctx.is_input_fp8: dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( None, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 60c73a8d7d..d7a7f20dc4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -24,12 +24,14 @@ ) from ..fp8 import FP8GlobalStateManager from ..utils import ( + assert_dim_for_fp8_exec, + cast_if_needed, + clear_tensor_data, divide, get_default_init_method, init_method_constant, - cast_if_needed, - assert_dim_for_fp8_exec, - clear_tensor_data, + nvtx_range_pop, + nvtx_range_push, requires_grad, ) from ..distributed import ( @@ -112,6 +114,12 @@ def forward( skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -121,10 +129,12 @@ def forward( 
assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP + nvtx_range_push(f"{nvtx_label}.norm_input_cast") inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + nvtx_range_pop(f"{nvtx_label}.norm_input_cast") tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( @@ -175,6 +185,7 @@ def forward( ) # Apply normalization + nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, @@ -188,9 +199,11 @@ def forward( zero_centered_gamma, ) ln_out_return = ln_out if return_layernorm_output else None + nvtx_range_pop(f"{nvtx_label}.norm") # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") if with_input_all_gather and not ub_overlap_ag_fprop: with_quantized_all_gather = fp8 if return_layernorm_output and return_layernorm_output_gathered: @@ -217,6 +230,7 @@ def forward( elif backward_needs_input: ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) ln_out_total = ln_out + nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype weightmat = weight @@ -275,6 +289,7 @@ def forward( assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." 
ln_out_total = ub_obj.get_buffer(input_quantizer) + nvtx_range_push(f"{nvtx_label}.gemm") out, *_, rs_out = general_gemm( weightmat, ln_out_total, @@ -287,6 +302,8 @@ def forward( ub_type=ub_type, extra_output=rs_out, ) + nvtx_range_pop(f"{nvtx_label}.gemm") + if not weight.requires_grad: if not return_layernorm_output: ln_out = ln_out_total = None @@ -307,6 +324,7 @@ def forward( # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, @@ -315,6 +333,7 @@ def forward( weightmat if quantized_weight else None, ln_out if weight.requires_grad else None, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") tensors_to_save, tensor_objects = prepare_for_saving( inputmat, @@ -372,10 +391,12 @@ def forward( if ub_overlap_rs_fprop: out = rs_out elif parallel_mode == "row": + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp_shape[1:-1], out_features) @@ -394,6 +415,11 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): if ( ctx.fp8 @@ -433,6 +459,7 @@ def backward( # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so 
we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( ctx.fsdp_group, ctx.fsdp_shapes, @@ -441,6 +468,7 @@ def backward( weight if ctx.fp8 and ctx.quantized_weight else None, ln_out, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. @@ -515,12 +543,14 @@ def backward( if ctx.fp8: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, quantizer=quantizer, ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: ln_out_total = ln_out @@ -536,6 +566,7 @@ def backward( if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad, *_ = general_gemm( weight, grad_output, @@ -551,12 +582,14 @@ def backward( extra_output=rs_out, bulk_overlap=ctx.ub_bulk_dgrad, ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") # Launch tensor-parallel communication dgrad_work = None if ctx.ub_overlap_rs_dgrad: dgrad = rs_out elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) @@ -567,6 +600,7 @@ def backward( ) else: dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") # Compute grad weight tensor wgrad = None @@ -603,6 +637,7 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") wgrad, grad_bias_, *_, rs_out = general_gemm( ln_out_total, grad_output, @@ -621,6 +656,7 @@ def backward( 
extra_output=rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): @@ -657,6 +693,7 @@ def backward( # Norm gradient dgamma = None dbeta = None + nvtx_range_push(f"{nvtx_label}.norm") if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, @@ -679,6 +716,7 @@ def backward( ) dgrad = dgrad.reshape(inputmat.size()) dbeta = None + nvtx_range_pop(f"{nvtx_label}.norm") clear_tensor_data(mu) clear_tensor_data(rsigma) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 460ce87bc6..415cc7d9a9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -22,12 +22,14 @@ from ._common import noop_cat, _fix_gathered_fp8_transpose from ..fp8 import FP8GlobalStateManager from ..utils import ( - divide, cast_if_needed, clear_tensor_data, + divide, init_method_constant, - requires_grad, non_tn_fp8_gemm_supported, + nvtx_range_pop, + nvtx_range_push, + requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -100,6 +102,11 @@ def forward( ) -> torch.Tensor: # pylint: disable=missing-function-docstring + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -110,6 +117,7 @@ def forward( # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.input_cast_comm") inputmat = inp inputmat_total = None with_input_all_gather_nccl = ( @@ -153,6 +161,7 @@ def forward( inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat + nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype weightmat = weight @@ -216,6 
+225,7 @@ def forward( ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True) inputmat_total = ub_obj.get_buffer(input_quantizer) + nvtx_range_push(f"{nvtx_label}.gemm") out, *_, rs_out = general_gemm( weightmat, inputmat_total, @@ -228,6 +238,7 @@ def forward( ub_type=ub_type, extra_output=rs_out, ) + nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: saved_inputmat = None @@ -244,12 +255,14 @@ def forward( # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( @@ -299,10 +312,12 @@ def forward( if ub_overlap_rs_fprop: out = rs_out elif parallel_mode == "row": + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") out = out.view(-1, *inp_shape[1:-1], out_features) return out @@ -311,6 +326,11 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + with torch.cuda.nvtx.range("_Linear_backward"): if ( ctx.fp8 @@ -347,12 +367,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the 
base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( ctx.fsdp_group, ctx.fsdp_shapes, inputmat, weight_fp8, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -424,12 +446,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.fp8: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, quantizer=quantizer, ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: inputmat_total = inputmat @@ -451,6 +475,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # dgrad GEMM + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -466,11 +491,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], extra_output=rs_out, bulk_overlap=ctx.ub_bulk_dgrad, ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") # Launch tensor-parallel communication if ctx.ub_overlap_rs_dgrad: dgrad = rs_out elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") if ctx.sequence_parallel: dgrad, dgrad_work = reduce_scatter_along_first_dim( dgrad, @@ -479,6 +506,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) else: dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") # Compute grad weight tensor wgrad = None @@ -515,6 +543,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") 
wgrad, grad_bias_, _, rs_out = general_gemm( inputmat_total, grad_output, @@ -533,6 +562,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], extra_output=rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5b1bd82221..1922a7e867 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -6,6 +6,7 @@ from __future__ import annotations import functools import math +import os from typing import Any, Callable, List, Optional, Tuple import torch @@ -326,3 +327,62 @@ def round_up_to_nearest_multiple(value, multiple): if multiple == 0: raise ValueError("multiple cannot be zero.") return ((value + multiple - 1) // multiple) * multiple + + +@functools.lru_cache(maxsize=None) +def _nvtx_enabled() -> bool: + """Check if NVTX range profiling is enabled""" + return bool(int(os.getenv("NVTE_NVTX_ENABLED", "0"))) + + +# Messages associated with active NVTX ranges +_nvtx_range_messages: list[str] = [] + + +def nvtx_range_push(msg: str) -> None: + """Push NVTX range onto stack, if NVTX range profiling is enabled + + Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range + profiling. + + Parameters + ---------- + msg: str + Message to associate with range + + """ + if not _nvtx_enabled(): + return + _nvtx_range_messages.append(msg) + torch.cuda.nvtx.range_push(msg) + + +def nvtx_range_pop(msg: Optional[str] = None) -> None: + """Pop NVTX range from stack, if NVTX range profiling is enabled + + Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range + profiling. 
+ + Parameters + ---------- + msg: str, optional + Message associated with range + + """ + + # Return immediately if NVTX range profiling is not enabled + if not _nvtx_enabled(): + return + + # Update list of NVTX range messages and check for consistency + if not _nvtx_range_messages: + raise RuntimeError("Attempted to pop NVTX range from empty stack") + last_msg = _nvtx_range_messages.pop() + if msg is not None and msg != last_msg: + raise ValueError( + f"Attempted to pop NVTX range from stack with msg={msg}, " + f"but last range has msg={last_msg}" + ) + + # Pop NVTX range + torch.cuda.nvtx.range_pop() From 612637cbe54b5d019ab79f02224bef9fed8adbda Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Tue, 11 Feb 2025 22:28:08 -0800 Subject: [PATCH 071/239] WIP: non-paged, bshd/sbhd Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_fused_attn.py | 4 +- tests/pytorch/fused_attn/test_paged_attn.py | 288 ++++++++++-------- transformer_engine/pytorch/attention.py | 269 +++++++++------- .../pytorch/kv_cache_manager_non_paged.py | 42 ++- .../pytorch/kv_cache_manager_paged.py | 3 +- 5 files changed, 358 insertions(+), 248 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 7fa935cb34..775bf1651e 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -97,7 +97,8 @@ def __init__( num_layers: int = 1, bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), - total_requests: int = 1, + total_requests: int = None, + max_ctx_len: int = None, ): self.batch_size = batch_size self.num_heads = num_heads @@ -117,6 +118,7 @@ def __init__( self.bias_shape = bias_shape self.window_size = window_size self.total_requests = total_requests + self.max_ctx_len = max_ctx_len @contextmanager diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 9b696d9535..c646af9722 100644 --- 
a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -6,6 +6,7 @@ from typing import List import os import logging +import math import pytest import torch @@ -38,7 +39,7 @@ model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "infer_0": ModelConfig(4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8), + "infer_0": ModelConfig(4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16), #"infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), } @@ -48,31 +49,33 @@ def to_pretty_string(x: torch.Tensor): return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]" +def round_up(a: int, b: int): + return b * math.ceil(a / b) + class Simulation: def __init__( self, total_requests: int = 10, - max_seqlen_kv: int = 1024, - context_ratio: float = 0.25, + max_seq_len: int = 1024, + max_ctx_len: int = 128, max_batch_size: int = 5, - poisson_rate: float = None, + poisson_rate: float = 1, ): self.total_requests = total_requests - self.max_seqlen_kv = max_seqlen_kv - self.context_ratio = context_ratio + self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size self.poisson_rate = poisson_rate # calculate maximum context/generation length - self.max_context_len = int(max_seqlen_kv * context_ratio) - self.max_gen_len = max_seqlen_kv - self.max_context_len + self.max_ctx_len = max_ctx_len + self.max_gen_len = max_seq_len - self.max_ctx_len # simulate sequence ids in monotonically increasing fashion self.seq_ids = torch.range(0, total_requests-1, dtype=torch.int32, device="cpu") # simulate context lengths in Uniform distribution #self.context_lens = torch.randint( - # 1, self.max_context_len, [total_requests], dtype=torch.int32, device="cpu" + # 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" #) self.context_lens = 10 * torch.ones(total_requests, dtype=torch.int32, device="cpu") @@ -127,34 +130,35 @@ 
def reset(self): # step info from step t-1 to t self.step_lens = torch.Tensor([]).to(dtype=torch.int32, device="cpu") - def print(self, logger, label="setup"): - if label == "setup": - logger.info("Simulation:") - logger.info(" {:<33s}: {}".format("total number of requests", self.total_requests)) - logger.info(" {:<33s}: {}".format("max sequence length per request", self.max_seqlen_kv)) - logger.info(" {:<33s}: {}".format("max context lengh", self.max_context_len)) - logger.info(" {:<33s}: {}".format("max generation lengh", self.max_gen_len)) - logger.info(" {:<33s}: {}".format("max batch size per iteration", self.max_batch_size)) - logger.info(" {:<33s}: {}".format("Poisson rate", self.poisson_rate)) - logger.info(" {:<18s}: {}".format("sequence ids", to_pretty_string(self.seq_ids))) - logger.info(" {:<18s}: {}".format("arrival times", to_pretty_string(self.arrival_times))) - logger.info(" {:<18s}: {}".format("context lenghs", to_pretty_string(self.context_lens))) - logger.info(" {:<18s}: {}".format("generation lenghs", to_pretty_string(self.gen_lens))) - if label == "step": - logger.info(f"Step t = {self.t}:") - logger.info(" {:<15s}: {}".format("t_batch_size", self.t_batch_size)) - logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist())) - logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist())) - logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist())) - logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist())) - logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist())) - if label == "summary": - logger.info("Summary:") - logger.info(" {:<18s}: {}".format("total steps taken", self.t)) - logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times))) - logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times))) - logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens))) - logger.info(" 
{:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times))) + def print_setup(self, logger): + logger.info("Simulation:") + logger.info(" {:<31s}: {}".format("total number of requests", self.total_requests)) + logger.info(" {:<31s}: {}".format("max sequence length per request", self.max_seq_len)) + logger.info(" {:<31s}: {}".format("max context length", self.max_ctx_len)) + logger.info(" {:<31s}: {}".format("max generation length", self.max_gen_len)) + logger.info(" {:<31s}: {}".format("max batch size per iteration", self.max_batch_size)) + logger.info(" {:<31s}: {}".format("Poisson rate", self.poisson_rate)) + logger.info(" {:<17s}: {}".format("sequence ids", to_pretty_string(self.seq_ids))) + logger.info(" {:<17s}: {}".format("arrival times", to_pretty_string(self.arrival_times))) + logger.info(" {:<17s}: {}".format("context lengths", to_pretty_string(self.context_lens))) + logger.info(" {:<17s}: {}".format("generation lengths", to_pretty_string(self.gen_lens))) + + def print_step(self, logger): + logger.info(f"Step t = {self.t}:") + logger.info(" {:<15s}: {}".format("t_batch_size", self.t_batch_size)) + logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist())) + logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist())) + logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist())) + logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist())) + logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist())) + + def print_summary(self, logger): + logger.info("Summary:") + logger.info(" {:<18s}: {}".format("total steps taken", self.t)) + logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times))) + logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times))) + logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens))) + logger.info(" {:<18s}: {}".format("complete_times", 
to_pretty_string(self.complete_times))) def add_new_seqs(self, new_seq_ids): # get ctx_lens for new seqs @@ -203,7 +207,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("dtype", [torch.float16])#param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", ['bshd'])#qkv_formats) +@pytest.mark.parametrize("qkv_format", ['thd'])#qkv_formats) @pytest.mark.parametrize("is_paged", [False])#, True]) @pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False])#, True]) @@ -212,6 +216,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): logger = logging.getLogger("test_paged_attn") config = model_configs_infer[model] + num_layers = 2 layer_number = 1 # figure out supported backends @@ -238,63 +243,41 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention")) os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) - # set up various parameters - total_requests = config.total_requests - max_batch_size = config.batch_size - max_seqlen_kv = config.max_seqlen_kv - attn_mask_type = "padding" - page_size = 256 if backend == "FlashAttention" else 16 - max_seqlen_kv_roundup = max_seqlen_kv - if is_paged: - # round up max_seqlen_kv to nearest page size - max_seqlen_kv_roundup = int((max_seqlen_kv + page_size - 1) // page_size * page_size) - else: - # round up max_seqlen_kv to nearest multiple of 64 - max_seqlen_kv_roundup = int((max_seqlen_kv + 63) // 64 * 64) - cache_size = max_batch_size * max_seqlen_kv_roundup - total_num_pages = int(cache_size / page_size) - - # set up simulation - sim = Simulation( - total_requests=total_requests, - max_seqlen_kv=max_seqlen_kv, - context_ratio=0.25, - max_batch_size=max_batch_size, - poisson_rate=2, - ) - sim.print(logger, label="setup") - 
- # create model and data + # create model + # TODO: multi layers [num_layers] model = ( DotProductAttention( kv_channels=config.head_dim_qk, num_attention_heads=config.num_heads, num_gqa_groups=config.num_gqa_groups, layer_number=layer_number, - attention_dropout=0.0, - attn_mask_type="causal", - qkv_format="bshd", + attention_dropout=config.dropout_p, ) .cuda() .eval() ) + + # generate data for all requests + assert ( + config.max_seqlen_q == config.max_seqlen_kv + ), "This test only simulates max_seqlen_q = max_seqlen_kv." q = 0.1 * torch.randn( - (total_requests, max_seqlen_kv, config.num_heads, config.head_dim_qk), + (config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk), dtype=dtype, device="cuda", ) k = 0.1 * torch.randn( - (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk), + (config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk), dtype=dtype, device="cuda", ) v = 0.1 * torch.randn( - (total_requests, max_seqlen_kv, config.num_gqa_groups, config.head_dim_v), + (config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v), dtype=dtype, device="cuda", ) - # generate all tokens at once + # generate reference results logger.info("=== Generating all tokens at once ===") full_output = model( query_layer=q, @@ -304,11 +287,29 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): attn_mask_type="causal", ) - # generate tokens one at a time + # simulate real-life inference logger.info("=== Generating one token at a time ===") + max_batch_size = config.batch_size + page_size = None + total_num_pages = None + if is_paged: + page_size = 256 if backend == "FlashAttention" else 16 + config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) + total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) + else: + config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64) + sim = Simulation( + 
total_requests=config.total_requests, + max_seq_len=config.max_seqlen_kv, + max_ctx_len=config.max_ctx_len, + max_batch_size=max_batch_size, + poisson_rate=2, + ) + sim.print_setup(logger) + inference_params = InferenceParams( max_batch_size=max_batch_size, - max_seqlen_kv=max_seqlen_kv_roundup, + max_seqlen_kv=config.max_seqlen_kv, num_heads_kv=config.num_gqa_groups, head_dim_k=config.head_dim_qk, head_dim_v=config.head_dim_v, @@ -316,24 +317,34 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): is_paged=is_paged, page_size=page_size, total_num_pages=total_num_pages, - is_cuda_graph=is_cuda_graph, num_heads_q=config.num_heads, head_dim_q=config.head_dim_qk, + max_ctx_len=config.max_ctx_len, + qkv_format=qkv_format, ) + # TODO: num_layers inference_params.allocate_memory(layer_number, qkv_format) - inference_params.print() + #inference_params.print() def generate_data( model_config: ModelConfig, dtype: torch.dtype, warmup: bool = False, + qkv_format: str = "bshd", ) -> List[torch.Tensor]: """Generate synthetic data for dot product attention.""" gen_func = torch.ones if warmup else torch.randn + if qkv_format == "bshd": + shape = [ model_config.batch_size, model_config.max_ctx_len] + if qkv_format == "sbhd": + shape = [ model_config.max_ctx_len, model_config.batch_size] + if qkv_format == "thd": + shape = [ model_config.batch_size * model_config.max_ctx_len] aa=[ gen_func( - model_config.batch_size, - 64, #model_config.max_seqlen_q, + #model_config.max_ctx_len, + #model_config.batch_size, + *shape, model_config.num_heads, model_config.head_dim_qk, device="cuda", @@ -351,54 +362,59 @@ def gen_cu( ): cu_dict = {} cu_dict["cu_seqlens_q"] = torch.linspace( 0, - model_config.batch_size * 1, #model_config.max_seqlen_q, + model_config.batch_size * model_config.max_ctx_len, #model_config.batch_size * model_config.max_seqlen_q, steps=model_config.batch_size+1, device="cuda", dtype=torch.int32, ) cu_dict["cu_seqlens_kv"] = torch.linspace( 0, - 
model_config.batch_size * 1, #model_config.max_seqlen_kv, + model_config.batch_size * model_config.max_ctx_len, + #model_config.batch_size * 1, #model_config.max_seqlen_kv, #model_config.batch_size * model_config.max_seqlen_kv, steps=model_config.batch_size+1, device="cuda", dtype=torch.int32, ) - cu_dict["max_seqlen_q"] = model_config.max_seqlen_q - cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv + #cu_dict["max_seqlen_q"] = model_config.max_seqlen_q + #cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv cu_dict["inference_params"] = inference_params - cu_dict["attn_mask_type"] = attn_mask_type - #cu_dict["max_seqlen_q"] = max_seqlen_q_infer - #cu_dict["max_seqlen_kv"] = max_seqlen_kv_roundup + cu_dict["attn_mask_type"] = "padding" #"causal" + # for qkv_format = thd + cu_dict["max_seqlen_q"] = model_config.max_ctx_len #max_seqlen_q_infer + cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv cu_dict["qkv_format"] = qkv_format return cu_dict -# t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") -# step_lens = torch.ones(max_batch_size, dtype=torch.int32, device="cpu") -# step_dict = OrderedDict( -# zip(t_seq_ids.tolist(), step_lens.tolist()) -# ) -# inference_params.prepare(step_dict) -# model = make_graphed_callables( -# model, -# generate_data(config, dtype, warmup=True), -# num_warmup_iters=10, -# fp8_enabled=False, -# #sample_kwargs={"qkv_format":"thd"}, -# sample_kwargs=gen_cu(config, dtype), -# ) -# print('AAAAAAAAAAAAfter graphed') + t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") + step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu") + step_dict = OrderedDict( + zip(t_seq_ids.tolist(), step_lens.tolist()) + ) + inference_params.prepare(step_dict) + if is_cuda_graph: + model = make_graphed_callables( + model, + generate_data(config, dtype, warmup=True, qkv_format=qkv_format), + num_warmup_iters=10, + fp8_enabled=False, + #sample_kwargs={"qkv_format":"thd"}, 
+ sample_kwargs=gen_cu(config, dtype), + ) + print('AAAAAAAAAAAAfter graphed') # similate step by step sim.reset() + inference_params.reset() graphed = False model_orig = model + max_tokens = config.batch_size * config.max_ctx_len while True: if inference_params.is_paged: inference_params.cache_manager.print_cache() dynamic_fill = True #inference_params.is_paged sim.step(dynamic_fill=dynamic_fill) - sim.print(logger, label="step") + sim.print_step(logger) if sim.t_batch_size == 0: # all sequences are finished @@ -411,11 +427,14 @@ def gen_cu( sim.t += 1 continue - if not is_cuda_graph: - max_seqlen_q_infer = int((sim.max_context_len + 63)// 64 * 64) - else: - max_seqlen_q_infer = max_seqlen_kv_roundup + #if not is_cuda_graph: + # max_seqlen_q_infer = int((sim.max_ctx_len + 63)// 64 * 64) + #else: + # max_seqlen_q_infer = max_seqlen_kv_roundup + batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size + max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() #max_seqlen_q_infer, + #max_seqlen_q_infer = sim.max_ctx_len # create incremental input if qkv_format == "thd": incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") @@ -439,26 +458,34 @@ def gen_cu( ], dim=0, ) + if is_cuda_graph: + incremental_q = torch.cat([incremental_q, torch.zeros([max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], dtype=dtype, device=incremental_q.device)], dim=0) + incremental_k = torch.cat([incremental_k, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_k.device)], dim=0) + incremental_v = torch.cat([incremental_v, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_v.device)], dim=0) else: incremental_q = torch.zeros( - sim.t_batch_size, - max_seqlen_q_infer, + batch_size, + #sim.max_ctx_len, #max_seqlen_q_infer, + max_seqlen_q, config.num_heads, config.head_dim_qk, dtype=dtype, 
device="cuda", ) incremental_k = torch.zeros( - sim.t_batch_size, - max_seqlen_q_infer, + batch_size, + #sim.max_ctx_len, #max_seqlen_q_infer, + max_seqlen_q, config.num_gqa_groups, config.head_dim_qk, dtype=dtype, device="cuda", ) incremental_v = torch.zeros( - sim.t_batch_size, - max_seqlen_q_infer, + #sim.t_batch_size, + batch_size, + #sim.max_ctx_len, #max_seqlen_q_infer, + max_seqlen_q, config.num_gqa_groups, config.head_dim_v, dtype=dtype, @@ -475,9 +502,10 @@ def gen_cu( x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] ] - cu_seqlens_q = torch.zeros(sim.t_batch_size + 1, dtype=torch.int32, device="cuda") + batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) - cu_seqlens_kv = torch.zeros(sim.t_batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) step_dict = OrderedDict( @@ -496,17 +524,17 @@ def gen_cu( # ) # graphed = True # print('AAAAAAAAAAAAfter graphed') - if not graphed: - model = make_graphed_callables( - model, - generate_data(config, dtype, warmup=True), - num_warmup_iters=10, - fp8_enabled=False, - #sample_kwargs={"qkv_format":"thd"}, - sample_kwargs=gen_cu(config, dtype), - ) - graphed = True - print('AAAAAAAAAAAAfter graphed') + #if not graphed: + # model = make_graphed_callables( + # model, + # generate_data(config, dtype, warmup=True), + # num_warmup_iters=10, + # fp8_enabled=False, + # #sample_kwargs={"qkv_format":"thd"}, + # sample_kwargs=gen_cu(config, dtype), + # ) + # graphed = True + # print('AAAAAAAAAAAAfter graphed') print('incremental shapes', [x.shape for x in [ incremental_q, incremental_k, incremental_v]]) #if sim.step_lens[0] == 1 and graphed: @@ -523,11 +551,12 @@ def gen_cu( 
cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, - attn_mask_type=attn_mask_type, - max_seqlen_q=max_seqlen_q_infer, - max_seqlen_kv=max_seqlen_kv_roundup, + attn_mask_type="padding", + max_seqlen_q=max_seqlen_q, #config.max_ctx_len, #max_seqlen_q_infer, + max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, ) + print('llllllllllllllll ', line_output.shape) if backend != "FlashAttention": tols = { @@ -560,6 +589,9 @@ def gen_cu( rtol=tols[dtype], ) if qkv_format == "thd": + print('thd ', seq, sim.t_total_lens[i], cu_seqlens_q[i + 1]) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], line_output[cu_seqlens_q[i + 1] - 1, :], @@ -571,4 +603,4 @@ def gen_cu( sim.serving_times = sim.arrival_times + sim.request_delays sim.complete_times = sim.serving_times + sim.gen_lens - sim.print(logger, label="summary") + sim.print_summary(logger) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 54d14497fb..85bc4fa6ac 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1026,6 +1026,35 @@ def get_attention_backend( available_backends, ) +class KVCacheManager: + """ + KV cache manager. This should be the base class for custom KV cache managers. + """ + def __init__(self, *args, **kwargs): + """Initialize the cache manager""" + self.cache = {} + def allocate_memory(self, layer_number: int): + """Allocate memory for the cache""" + self.cache[layer_number] = (None, None) + def prepare( + self, + sequences: Dict[List, List], + step_dict: Dict[List, List], + ): + """Prepare for step(). 
Update sequences with step_dict.""" + return sequences + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + qkv_format: str, + ): + """Update the cache with new_k and new_v tokens""" + return *self.cache[layer_number], None + class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order @@ -1061,51 +1090,54 @@ def __init__( num_heads_kv: int, head_dim_k: int, dtype: torch.dtype, - head_dim_v: Optional[int] = None, + head_dim_v: int = None, is_paged: bool = False, - total_num_pages: Optional[int] = None, - page_size: Optional[int] = None, - is_cuda_graph: bool = False, - num_heads_q: Optional[int] = None, - head_dim_q: Optional[int] = None, + total_num_pages: int = None, + page_size: int = None, + num_heads_q: int = None, + head_dim_q: int = None, + max_ctx_len: int = None, + qkv_format: str = "bshd", + cache_manager: KVCacheManager = None, ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv self.num_heads_kv = num_heads_kv self.head_dim_k = head_dim_k - assert dtype in [torch.float32, torch.float16, torch.bfloat16], ( - "Supported InferenceParams.dtype = {torch.float32, torch.float16, torch.bfloat16}." - " Found {dtype}." 
- ) self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k self.is_paged = is_paged - self.is_cuda_graph = is_cuda_graph - self.page_table = None + #self.page_table = None if not self.is_paged: - self.cache_manager = NonPagedKVCacheManager( + cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager + self.cache_manager = cls( max_batch_size=self.max_batch_size, max_seqlen=self.max_seqlen_kv, num_heads=self.num_heads_kv, head_dim_k=self.head_dim_k, dtype=self.dtype, head_dim_v=self.head_dim_v, - is_cuda_graph=self.is_cuda_graph, ) else: - assert page_size is not None, "page_size is required when is_paged=True!" - assert total_num_pages is not None, "total_num_pages is required when is_paged=True!" + assert page_size is not None, "Paged KV cache requires page_size!" + assert max_seqlen_kv % page_size == 0, "Paged KV cache requires max_seqlen_kv % page_size = 0!" + max_pages_per_seq = max_seqlen_kv // page_size + assert ( + total_num_pages == self.max_batch_size * max_pages_per_seq + ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq!" 
self.page_size = page_size - self.max_seqlen_kv = ( - self.max_seqlen_kv - if self.max_seqlen_kv >= self.page_size - else int( - (self.max_seqlen_kv + self.page_size - 1) // self.page_size * self.page_size - ) - ) + #self.max_seqlen_kv = ( + # self.max_seqlen_kv + # if self.max_seqlen_kv >= self.page_size + # else int( + # (self.max_seqlen_kv + self.page_size - 1) // self.page_size * self.page_size + # ) + #) + self.max_seqlen_kv = max_seqlen_kv self.total_num_pages = total_num_pages - self.cache_manager = PagedKVCacheManager( + cls = cache_manager if cache_manager is not None else PagedKVCacheManager + self.cache_manager = cls( total_num_pages=self.total_num_pages, page_size=self.page_size, num_heads=self.num_heads_kv, @@ -1114,17 +1146,19 @@ def __init__( max_batch_size=self.max_batch_size, max_seqlen=self.max_seqlen_kv, head_dim_v=self.head_dim_v, - is_cuda_graph=self.is_cuda_graph, ) - if self.is_cuda_graph: - assert num_heads_q is not None, "num_heads_q is required when is_cuda_graph=True!" - assert head_dim_q is not None, "head_dim_q is required when is_cuda_graph=True!" + if qkv_format == "thd": + assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" + assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" + assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" 
self.num_heads_q = num_heads_q self.head_dim_q = head_dim_q + self.max_ctx_len = max_ctx_len + + # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache + self.inference_qkv_format = "bshd" - # memory format for the cache; at the moment, only 'bshd' is supported - self.qkv_format = "bshd" # layer numbers that we have kv cache for #self.layer_numbers = [] # sequence ids that are stored in the cache @@ -1137,28 +1171,31 @@ def __init__( # we have three sequences in the batch: sequences 2 and 3 are in generation phase # with step_len = 1 and sequence 4 is in context phase with 10 new tokens #self.step_lens = [] + + # TODO: needed? self.step_dict = collections.OrderedDict() + # the query buffer when is_cuda_graph = True #if self.is_cuda_graph: # self.q_buffer = {} # self.cu_seqlens_q_buffer = [] # self.cu_seqlens_kv_buffer = [] - def print(self): - """Print InferenceParams parameters""" - logger = logging.getLogger("InferenceParams") - logger.debug("InferenceParams:") - logger.debug(" dtype: %s", self.dtype) - logger.debug(" is_paged: %s", self.is_paged) - if not self.is_paged: - logger.debug(" max_batch_size: %s", self.max_batch_size) - logger.debug(" max_seqlen_kv: %s", self.max_seqlen_kv) - else: - logger.debug(" total_num_pages: %s", self.total_num_pages) - logger.debug(" page_size: %s", self.page_size) - logger.debug(" num_heads_kv: %s", self.num_heads_kv) - logger.debug(" head_dim: k: %s, v: %s", self.head_dim_k, self.head_dim_v) - #logger.debug(" layer_numbers: %s", self.layer_numbers) + #def print(self): + # """Print InferenceParams parameters""" + # logger = logging.getLogger("InferenceParams") + # logger.debug("InferenceParams:") + # logger.debug(" dtype: %s", self.dtype) + # logger.debug(" is_paged: %s", self.is_paged) + # if not self.is_paged: + # logger.debug(" max_batch_size: %s", self.max_batch_size) + # logger.debug(" max_seqlen_kv: %s", self.max_seqlen_kv) + # else: + # logger.debug(" total_num_pages: %s", self.total_num_pages) + # 
logger.debug(" page_size: %s", self.page_size) + # logger.debug(" num_heads_kv: %s", self.num_heads_kv) + # logger.debug(" head_dim: k: %s, v: %s", self.head_dim_k, self.head_dim_v) + # #logger.debug(" layer_numbers: %s", self.layer_numbers) def allocate_memory(self, layer_number: int, qkv_format: str): """ @@ -1177,17 +1214,18 @@ def allocate_memory(self, layer_number: int, qkv_format: str): """ #self.layer_numbers.append(layer_number) + self.cache_manager.allocate_memory(layer_number) if qkv_format == 'thd': #self.is_cuda_graph: - self.max_seqlen_q = self.max_seqlen_kv + #self.max_seqlen_q = self.max_seqlen_kv + self.q_buffer = {} self.q_buffer[layer_number] = torch.zeros( self.max_batch_size, - self.max_seqlen_q, + self.max_ctx_len, self.num_heads_q, self.head_dim_q, dtype=self.dtype, device=torch.cuda.current_device(), ) - self.cache_manager.allocate_memory(layer_number) self.cu_seqlens_q = torch.zeros( self.max_batch_size + 1, dtype=torch.int32, @@ -1198,18 +1236,25 @@ def allocate_memory(self, layer_number: int, qkv_format: str): dtype=torch.int32, device=torch.cuda.current_device(), ) + def reset(self): + #self.cu_seqlens_q.fill_(0) + #self.cu_seqlens_kv.fill_(0) + self.sequences = collections.OrderedDict() #zip(self.seq_ids, self.seq_lens)) + self.step_dict = collections.OrderedDict() def prepare( self, step_dict: Dict[List, List], ): self.sequences = self.cache_manager.prepare(self.sequences, step_dict) + self.step_dict = step_dict actual_batch_size = len(self.step_dict) seqlens_q = list(self.step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] self.seq_lens = list(self.sequences.values()) + self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") ) @@ -1411,17 +1456,20 @@ def update_cache( page_table: torch.Tensor The page table if is_paged = True; else `None` """ - actual_batch_size = len(self.step_dict) - seqlens_q = list(self.step_dict.values()) - 
cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + #actual_batch_size = len(self.step_dict) + #seqlens_q = list(self.step_dict.values()) + #cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + seqlens_q = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + batch_size = len(seqlens_q) if qkv_format == "bshd": q = q.contiguous() if qkv_format == "sbhd": q = q.transpose(0, 1).contiguous() if qkv_format == "thd": q_buffer = self.q_buffer[layer_number] - for i in range(actual_batch_size): - q_buffer[i, : seqlens_q[i], :, :] = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, :] + #for i in range(actual_batch_size): + for i in range(batch_size): + q_buffer[i, : seqlens_q[i], :, :] = q[self.cu_seqlens_q[i] : self.cu_seqlens_q[i + 1], :, :] q = q_buffer #self.page_table = page_table @@ -7754,50 +7802,10 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - cp_size = 1 - if isinstance(self.cp_group, dist_group_type): - cp_size = get_distributed_world_size(self.cp_group) - elif isinstance(self.cp_group, list): - for group in self.cp_group: - cp_size *= get_distributed_world_size(group) - context_parallel = cp_size > 1 - if qkv_format in ["sbhd", "bshd"]: assert all( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" 
- if qkv_format == "sbhd": - max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[1] - else: - max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[0] - max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size - if cu_seqlens_q is None or cu_seqlens_kv is None: - if "padding" in attn_mask_type: - assert ( - attention_mask is not None - ), "Please provide attention_mask for padding!" - if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) - cu_seqlens_kv = cu_seqlens_q - else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) - else: - cu_seqlens_q = _get_full_cu_seqlens( - batch_size, - max_seqlen_q, - query_layer.device, - ) - cu_seqlens_kv = _get_full_cu_seqlens( - batch_size, - max_seqlen_kv, - key_layer.device, - ) page_table = None if inference_params is not None: @@ -7808,8 +7816,10 @@ def forward( # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference - if attn_mask_type in ["causal", "padding_causal"]: - attn_mask_type = attn_mask_type + "_bottom_right" + if "padding" not in attn_mask_type: + attn_mask_type = "padding_" + attn_mask_type +# if attn_mask_type in ["causal", "padding_causal"]: +# attn_mask_type = attn_mask_type + "_bottom_right" # convert to cross attention type when KV cache is in use self.attention_type = "cross" @@ -7841,8 +7851,8 @@ def forward( value_layer, qkv_format, ) - #print('cu_seqlens_q',cu_seqlens_q) - #print('cu_seqlens_kv',cu_seqlens_kv) + print('cu_seqlens_q',cu_seqlens_q) + print('cu_seqlens_kv',cu_seqlens_kv) # update cu_seqlens tensors #if inference_params.is_cuda_graph: @@ -7853,7 
+7863,51 @@ def forward( # query tensor is now in inference_params.qkv_format #qkv_format = target_qkv_format - qkv_format = inference_params.qkv_format + qkv_format = inference_params.inference_qkv_format + + cp_size = 1 + if isinstance(self.cp_group, dist_group_type): + cp_size = get_distributed_world_size(self.cp_group) + elif isinstance(self.cp_group, list): + for group in self.cp_group: + cp_size *= get_distributed_world_size(group) + context_parallel = cp_size > 1 + + if qkv_format in ["sbhd", "bshd"]: + if qkv_format == "sbhd": + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + batch_size = query_layer.shape[1] + else: + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv + batch_size = query_layer.shape[0] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size + if cu_seqlens_q is None or cu_seqlens_kv is None: + if "padding" in attn_mask_type: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" 
+ if self.attention_type == "self": + cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_kv = cu_seqlens_q + else: + cu_seqlens_q = get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + else: + cu_seqlens_q = _get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, + ) + cu_seqlens_kv = _get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + print('max_seqlen_q ', max_seqlen_q) + print('max_seqlen_kv ', max_seqlen_kv) if ( isinstance(query_layer, Float8Tensor) @@ -8094,6 +8148,9 @@ def forward( fp8_meta=self.fp8_meta, quantizers=self.quantizers, ) + print('ooooooooooo ',output.shape) + print(output[1,9,:4]) + print(output[1,10,:4]) from .cpu_offload import CPUOffloadEnabled @@ -8141,6 +8198,7 @@ def forward( batch_size = len(inference_params.step_dict) step_lens = list(inference_params.step_dict.values()) max_seqlen_q = max(list(inference_params.step_dict.values())) + print('xxxxxxxxx ', batch_size, step_lens, max_seqlen_q, inference_params.step_dict, orig_qkv_format) if orig_qkv_format == "bshd": output = output[:batch_size, :max_seqlen_q].contiguous() if orig_qkv_format == "sbhd": @@ -8504,19 +8562,6 @@ def __init__( **common_gemm_kwargs, ) - def _allocate_memory( - self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype - ) -> torch.Tensor: - """Allocates memory for KV cache.""" - return torch.empty( - inference_max_sequence_len, - batch_size, - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - dtype=dtype, - device=torch.cuda.current_device(), - ) - def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given @@ -8705,7 +8750,7 @@ def forward( # ================================================= if inference_params is not None and self.layer_number not in inference_params.layer_numbers: - inference_params.allocate_memory(self.layer_number) + 
inference_params.allocate_memory(self.layer_number, self.qkv_format) # ====================== # Query, Key, and Value diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 220d04bb9d..919ca8eeaf 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -10,8 +10,36 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat +class KVCacheManager: + """ + KV cache manager. This should be the base class for custom KV cache managers. + """ + def __init__(self, *args, **kwargs): + """Initialize the cache manager""" + self.cache = {} + def allocate_memory(self, layer_number: int): + """Allocate memory for the cache""" + self.cache[layer_number] = (None, None) + def prepare( + self, + sequences: Dict[List, List], + step_dict: Dict[List, List], + ): + """Prepare for step(). Update sequences with step_dict.""" + return sequences + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + qkv_format: str, + ): + """Update the cache with new_k and new_v tokens""" + return *self.cache[layer_number], None -class NonPagedKVCacheManager: +class NonPagedKVCacheManager(KVCacheManager): """ The non-paged KV cache manager. 
""" @@ -24,7 +52,7 @@ def __init__( head_dim_k: int, dtype: torch.dtype, head_dim_v: Optional[int] = None, - is_cuda_graph: bool = False, + #is_cuda_graph: bool = False, ): """Initialize the KV cache""" self.max_batch_size = max_batch_size @@ -33,12 +61,13 @@ def __init__( self.head_dim_k = head_dim_k self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - self.is_cuda_graph = is_cuda_graph + #self.is_cuda_graph = is_cuda_graph # sequences contained in the kv cache, {seq_id: seq_len} - self.sequences = OrderedDict() + #self.sequences = OrderedDict() # KV cache tuple (k_cache, v_cache) self.cache = {} + self.batch_indices = None # self._allocator = StaticBufferAllocator() # # def alloc(self, size, dtype, device): @@ -79,6 +108,7 @@ def prepare( sequences: Dict[List, List], step_dict: Dict[List, List], ): + # TODO: remove self.sequences = sequences #self.step_dict = step_dict prev_batch_size = len(self.sequences) @@ -153,8 +183,8 @@ def step( b=4 max_ctx_len=k.shape[1] #64 max_seq_len=k_cache.shape[1] #64 #128 - max_ctx_tokens=1 - max_tokens=1024 + max_ctx_tokens=k.shape[0] + max_tokens=k_cache.shape[0]*k_cache.shape[1] print('kv shapes ', [x.shape for x in [k, v, k_cache, v_cache]]) #print('step_lens ', step_lens) #print('seq_lens ', seq_lens) diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index 4066538dd7..e26ca22d5f 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -8,6 +8,7 @@ import logging import torch +from transformer_engine.pytorch.kv_cache_manager_non_paged import KVCacheManager class Page: @@ -27,7 +28,7 @@ def deallocate_page(self): self.allocated = False -class PagedKVCacheManager: +class PagedKVCacheManager(KVCacheManager): """ Paged KV cache manager. It supports a set of utilities including adding and removing sequences, and copying new key/value tokens to the cache. 
Users can overwrite this class From ee4a17de1e0d6b2031fc24b71728f851226920f7 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 12 Feb 2025 06:20:54 -0800 Subject: [PATCH 072/239] Update documentation for 2.0 release (#1479) * Updated docs for TE 2.0 Signed-off-by: Przemek Tredak * Do not expose comm_gemm_overlap and cast_transpose_noop Signed-off-by: Przemek Tredak * Made the figures larger Signed-off-by: Przemek Tredak * Apply suggestions from code review Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak * Update quickstart_utils.py Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Change from review Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Przemyslaw Tredak --------- Signed-off-by: Przemek Tredak Signed-off-by: Przemyslaw Tredak Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- README.rst | 47 ++++++++------ docs/api/c/{layer_norm.rst => fused_rope.rst} | 5 +- docs/api/c/index.rst | 12 ++-- docs/api/c/normalization.rst | 9 +++ docs/api/c/{rmsnorm.rst => padding.rst} | 7 ++- docs/api/c/permutation.rst | 10 +++ docs/api/c/recipe.rst | 10 +++ docs/api/c/swizzle.rst | 10 +++ docs/api/common.rst | 2 + docs/examples/E8M0.png | Bin 0 -> 30953 bytes docs/examples/MXFP8_FP8_comparison_1.png | Bin 0 -> 31195 bytes docs/examples/MXFP8_FP8_comparison_2.png | Bin 0 -> 115749 bytes docs/examples/fp8_primer.ipynb | 59 ++++++++++++++++-- docs/examples/linear_mxfp8.png | Bin 0 -> 49282 bytes docs/examples/quickstart_utils.py | 18 +++--- docs/installation.rst | 9 ++- transformer_engine/common/recipe/__init__.py | 19 ++++-- 17 files changed, 162 insertions(+), 55 deletions(-) rename docs/api/c/{layer_norm.rst => fused_rope.rst} (76%) create mode 100644 
docs/api/c/normalization.rst rename docs/api/c/{rmsnorm.rst => padding.rst} (72%) create mode 100644 docs/api/c/permutation.rst create mode 100644 docs/api/c/recipe.rst create mode 100644 docs/api/c/swizzle.rst create mode 100644 docs/examples/E8M0.png create mode 100644 docs/examples/MXFP8_FP8_comparison_1.png create mode 100644 docs/examples/MXFP8_FP8_comparison_2.png create mode 100644 docs/examples/linear_mxfp8.png diff --git a/README.rst b/README.rst index 8fea8c9d94..ace8096c42 100644 --- a/README.rst +++ b/README.rst @@ -33,11 +33,12 @@ What is Transformer Engine? .. overview-begin-marker-do-not-remove Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including -using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower -memory utilization in both training and inference. TE provides a collection of highly optimized -building blocks for popular Transformer architectures and an automatic mixed precision-like API that -can be used seamlessly with your framework-specific code. TE also includes a framework agnostic -C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers. +using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better +performance with lower memory utilization in both training and inference. TE provides a collection +of highly optimized building blocks for popular Transformer architectures and an automatic mixed +precision-like API that can be used seamlessly with your framework-specific code. TE also includes a +framework agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 +support for Transformers. As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT and T5 become very memory and compute-intensive. 
Most deep learning @@ -51,16 +52,16 @@ not available natively in frameworks today. TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer -layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support. -Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly -simplifying mixed precision training for users. +layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 +support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 +training, greatly simplifying mixed precision training for users. Highlights ========== * Easy-to-use modules for building Transformer layers with FP8 support * Optimizations (e.g. fused kernels) for Transformer models -* Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs +* Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs * Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later Examples @@ -149,22 +150,22 @@ Installation Pre-requisites ^^^^^^^^^^^^^^^^^^^^ * Linux x86_64 -* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada -* NVIDIA Driver supporting CUDA 12.0 or later -* cuDNN 8.1 or later -* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later. +* CUDA 12.1+ (CUDA 12.8+ for Blackwell) +* NVIDIA Driver supporting CUDA 12.1 or later +* cuDNN 9.3 or later Docker ^^^^^^^^^^^^^^^^^^^^ The quickest way to get started with Transformer Engine is by using Docker images on -`NVIDIA GPU Cloud (NGC) Catalog `_. For example to use the NGC PyTorch container interactively, +`NVIDIA GPU Cloud (NGC) Catalog `_. +For example to use the NGC PyTorch container interactively, .. 
code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3 -Where 23.10 is the container version. For example, 23.10 for the October 2023 release. +Where 25.01 (corresponding to January 2025 release) is the container version. pip ^^^^^^^^^^^^^^^^^^^^ @@ -174,15 +175,21 @@ To install the latest stable version of Transformer Engine, pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). +This will automatically detect if any supported deep learning frameworks are installed and build +Transformer Engine support for them. To explicitly specify frameworks, set the environment variable +NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). -Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g. +Alternatively, the package can be directly installed from +`Transformer Engine's PyPI `_, e.g. .. code-block:: bash pip install transformer_engine[pytorch] -To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions. +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be +explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). +Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX +and PyTorch extensions. 
From source ^^^^^^^^^^^ @@ -190,7 +197,7 @@ From source Compiling with FlashAttention-2 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. +Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance. It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. diff --git a/docs/api/c/layer_norm.rst b/docs/api/c/fused_rope.rst similarity index 76% rename from docs/api/c/layer_norm.rst rename to docs/api/c/fused_rope.rst index 3ac1c6842d..289bb53d9b 100644 --- a/docs/api/c/layer_norm.rst +++ b/docs/api/c/fused_rope.rst @@ -3,7 +3,8 @@ See LICENSE for license information. -layer_norm.h +fused_rope.h ============ -.. doxygenfile:: layer_norm.h +.. doxygenfile:: fused_rope.h + diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index d33e5ab607..7bc864dcc8 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -12,12 +12,16 @@ directly from C/C++, without Python. .. toctree:: :caption: Headers + transformer_engine.h activation.h cast.h - gemm.h fused_attn.h - layer_norm.h - rmsnorm.h + fused_rope.h + gemm.h + normalization.h + padding.h + permutation.h + recipe.h softmax.h - transformer_engine.h + swizzle.h transpose.h diff --git a/docs/api/c/normalization.rst b/docs/api/c/normalization.rst new file mode 100644 index 0000000000..edbea00ac0 --- /dev/null +++ b/docs/api/c/normalization.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +normalization.h +=============== + +.. 
doxygenfile:: normalization.h diff --git a/docs/api/c/rmsnorm.rst b/docs/api/c/padding.rst similarity index 72% rename from docs/api/c/rmsnorm.rst rename to docs/api/c/padding.rst index d6f378cebc..2141b874d2 100644 --- a/docs/api/c/rmsnorm.rst +++ b/docs/api/c/padding.rst @@ -3,7 +3,8 @@ See LICENSE for license information. -rmsnorm.h -============ +padding.h +========= + +.. doxygenfile:: padding.h -.. doxygenfile:: rmsnorm.h diff --git a/docs/api/c/permutation.rst b/docs/api/c/permutation.rst new file mode 100644 index 0000000000..bad6961621 --- /dev/null +++ b/docs/api/c/permutation.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +permutation.h +============= + +.. doxygenfile:: permutation.h + diff --git a/docs/api/c/recipe.rst b/docs/api/c/recipe.rst new file mode 100644 index 0000000000..7c368f69b6 --- /dev/null +++ b/docs/api/c/recipe.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +recipe.h +======== + +.. doxygenfile:: recipe.h + diff --git a/docs/api/c/swizzle.rst b/docs/api/c/swizzle.rst new file mode 100644 index 0000000000..b2dd8f5977 --- /dev/null +++ b/docs/api/c/swizzle.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +swizzle.h +========= + +.. doxygenfile:: swizzle.h + diff --git a/docs/api/common.rst b/docs/api/common.rst index 5e0a660ae6..95d4b50f30 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -9,3 +9,5 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.Format .. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) + +.. 
autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) diff --git a/docs/examples/E8M0.png b/docs/examples/E8M0.png new file mode 100644 index 0000000000000000000000000000000000000000..841df25e742860328549006768d860645e659fc4 GIT binary patch literal 30953 zcmeIbc|4Tu|2N#eNm{ugR4SFNk|xGhvP39C$P(EZ*=8`9v?)ST$TBL+WF5kYv9(xA z48}gjlx-T!3+S7)5g^Ef`oXFWde_i^4kZ=lV; zL3qQmWy|=_>YOoJwrnkC*|OzU>(&522~)e_4g9m*$4LA1vi#;m zaUy@Qhndx13z}Zni-QkV>{HEHwXI)Yx+43~YuB5$RWC{gw(Z|@)-7+%LlNgf+lIAT zTKD$YT%1VT|DL~4HE7j}vekSZ8$jY0^|RL8@3cERz_(C*@uBkGw~w@Ld65#5n(@I* zB5S;|Gtq|@%pt?-Y?m=}6}uL=9BxT-54n28!tkX3#HZKILd#dITDN((`t@c1=A!$2 zsHC&Ir?cZO!GCkdw?96&>NsY{ihsDtz4&_FtQM;N^_Fjc_WPG#ufquc=R1~#Mhdoc z_kWo_dGa5o(9&(X@;^^D~PY75hl3)!yT(mSg6~>)9{0z{-a!X ztE<~v^8bGF-`Ao>iw}^!Z{V(7{~%Up_nd#YV{^Cj@IJ}t~!^AA$6OG=IGe!SrSYUO%>3=hTZpgdPRd*?xM%{YpkN7;9{(QUdR1AiIJW2gT)%q#XhzcI`zC1(n)Q3}I1Y-OiPF z&h7p2)ZKT~IZ3C(>zbyHdLQ2KlW_?G+L*Hm_`xcnlY{r=%Pl=CP3*z`$lKqoX0x#^ zR>Xv1{AqT)2Tw$B*+-Y3(fr9^Lb-bYo!AZxVDsggaO$V^f@0tFR^54dH7?8m{j=G8 z0q!QQbZf|@&?pR?TD2ZGrP_M-WxO99PGTM4GV!UL>7mvdirU&uyH2_bxr|o@vqoqO z#dw3H`)BERrO_8@<$7|J#c-;d4Z#Z9<6vpuW<{WfFtI8j{yJ0l%_GzUAN6YrS@||7 z#=5dTNTK_bkU{;*3t!U0R3LKbO9`-$bVWy8)>u9^pMZtY`!3-NFZ*?v#Mb+h(Z2cp z*WU7nr5M{czYREQ6Fk#FjOKiahH8ri40TqqX!h7K?Q+>F=f}Cj=K;LwXHT!C8~=#mRa@j2aLNt|I$Z8PZ=?4@Q9xkQ#l+3|iT4!rG$U+-SO zVxBqcD?xeUgzU-SHmCBnIYkK|=HAAVV)4+dXUS?eajcKX0sYid{j-=-eaPJ=n+V7q zD-69Ssm-kHf*-c}&_`_O33lFY;)|qe)1FMj``~A5c1e8G7(m~_X6p7>JDO?0sP&4nDC0qB`IDpr0dGb6 zkIZYdy2~q)7Gt=<{ZKr)emvlT84MMyJl&#UnUjX%swfnpL~8J$+;h=8y~XmkGN!|) z_ONsc%|Ey1Z1*_#O}95++KpI~S~cl6TQ|`@ZUDNGsl|LRnOfl&8?eBIk%F0hIR5!N z(^SY28wU~|l;H_Wb5$CdhTYOc7$1q1TGvDT-ZX33K zh}49?f(wB*RD)3TjFH(|A)AUFR#jQh1x7GL8K=sBi1b1gZcl}>tdAcMt$|kKM~fiz9DD9LASm_!(YQdj-r>uj^FKU7ZL_U-Dp%2*|p=RvbRbJOHIJzcklaWmDNvFd)J zUzb(v7d2)X&abpTaO@50tc=C5@(^?0!LU@fK32?96`QThW=%KEMPS}j&hbx!X|L>E z?u#Hp!oK@*_3Jg8gj%{Gk7}m(On7+03F+@ss-|?uHJb{8x@OX%Rh0rfO&1IoTApy^ 
zS$q6j!6n|#^KT1x5F#IXfNrk3bHau(db}ud#RuL7fgNq$umE>R%ijRM1w<3s=7fvtb#Tx zJY}V)>opWQ02U=+eyI1PJ659LqHRAsC!Rmi+9QSI>^pEz(g3pJEGTLS?U0ikR$&kC zPuK&D8y0}8tCo)2Bc&X-lcoA&*g7k6a=ZKU3su-pl)0*@syEYd)cRdMA8)y~I^CpN z%M=sS8hB4pa*BWy+c+d?9+ia;8>c63;rsrK1Aw!eSyQI?m%2uRIZ*Hy&FJpHHgj7jzTnb1vRC6sj%yx`{K1j(6bdCUGs}N zgr@w2hUlZIMDn!6yjal0YmMg5n)arO0(dXO>UvM-kA%Dl|54-A-X_(QD-HT9E4MlS zlsoX&!1%!woy>q7#L>HvkU%+$vz5-`jgUL~h_wk!)*@626W^qm@-WDL0?@NADiI5V zU>rT`Y>Md)sR*e0q5mrWr-1k~(mj|5Cn-BJB4~n!O-0iUlNCl_aYx}K;bESuXv4n7 zCs1h%`P<7~)^8L{2zh4}vI-HRW|io>aCDGis*mz$%O3Xa|ENmi-j#^@0z}xD_sq7~ z%|>q*W(l6DyT--`8}Ze{j@K4Ql*{KBBVEhW5lJ5Yp3Jow)aJJo_D}+Tk&>CvV=q$v zMuO{gxGCOVjoIgeEoxSByOzIk-*@j+ux)i+J}_x_VA3gST_3Bo+`nu~z`)z^&vEN= z`kbprPf&b}IYEzost(qrT-p}(zxd(w zP^@9U(g#XiulpdRA6(K_qX`Q(76#j=Blice9Hay>U;#2!vNrH7Kghkd zf;h8R{4JD~0C|d{I&qIEoV>E;M+(w?YQ=t1VTzd-YJ$;JFu$f_QE{Xubfu*dAy3dy zWY~Y8K67(6zI)}OW;q|j0 zPu^Yd76?|LBUj2?KyVmLwSDm!v>#r5;>fPBGq!mvg2kFjJ6D$Y{A8ZleC{v}SB|tQ z{ld_^&)hh(ts0f2!@ze=>6baB21A)Lq$!zLm?s(SkgD%(4<_DHVwHT+>d?%o!B?Z~ zQy|4C6euIfiDBkiYy;$Y1V}7KzlmyAIcg$Rdr2h}`h9HykCxA?8EdA=%)q9QFCW83 zgUdt`j)xT&u4pth@*EEZANgruKY4jx?iMNo6j1!qcPLCB)LT7%+Vtaog=F!01vZk# zzwz~+E!}^GI^^su^}oZBH?PWAM1b*!Ckc{%NVc53%MT zAhHWVYyROjbr0aS-Zvw^lt;hLDD(nQ8l}x`Jn*|$eShb&&|cuSlG{y&zsmiWv8Dma zu06T)dmi|2o>~GFxo&((sXt!uTNTk-@9Z8uAfNCLx48nfLYLqPwckGDht}TT+R`mn z>vZHFZfgfjB45Tm_vo-Wbhy}JVTZ7RHdP5H|D_2q0IF@ z^pplY&zGYs_YUv7at{^w{(kSN^u@al7Y|(&l?IsvXzR;@(&n=HgkXnr0SxV$~!}LpgN3`s*EvrjD z9gk5>&~|&;p#S=`oU%!J-`$LBUfUC7jSQY!#n&1?t&uX=Lg~A=U9Xs-1v!j-CxME3 zBvN=^`cC-U;mp0XhAah)_6lJG+qFO)6@Bhu-K}* zAi}!^x^Gu(RCham1|ob)pgXc|cG+=_b>l~XFcPyK56t$G%d0!4`@2u$wLb6`?}VXB zcty)uOMhW04Bt(`L|H(jk<`8?2hdF z_7Y}aWX1F8c3wOjV$r^lK1F1iU59*18Y zDh8Cs%{^NB&WE=W9f_9?3pLhI(uVfy?h+=fSt+tON}3&l3J`u);q}_{R$W z+m7^)75;CEg@3H@{|i?5@^Gu}H!Q$^Ol|(xF6Z&huSOEx%99}FS@2JxYcaFXnGQ3U zkWG*mz0%E{c>wGSq}IQrk`LQN4k2|4&wPxmgFfSPeeJRrNb97(rghtuUHA-lWHmp3 zw{?Hif{CkNzm$}e1`i$w@>luOUm8eEp<);wBtOVepaE?y0+u1 
zHf*?zV9yX&3teu}2fZXs+yVKTS^1g3fSqeUIFm*#0Yy@+FNb;N7OHG5XuHOrr(1A4 zGxAX?Fcx>DY6y;V+62pS+1rt*9fPh$j)4{vNgYP zpr?Z2uyA*83OL5@9<;xqX~cT%-eE22(cIhW=}L`iZkxoEyKgkBuG@U+!fpi&X)+4t zE6$y(FV*S`Ag2(sP{=@?B6uxDDUxcEA?S`;8qb6d|5N9tI_Jq3CAyMcp)SLj%j_}^ z0Zp3MQOPF;!psfV{AIrV@g41Kg~E%6AI9Ld-s?R&>8(^N|FS?zAo`Jb;YHcFE%?T! zbrl~phBJ>%)KG>Uc3-}E#T&mY;FD}_x7fPU*KbuP?-}~?cEDgjQvk;6Dy|XF(ioF? zVCGfLBGPHiN@KaI8e<@gc*xvm+T36?T`YLI_4eA?77YRY4c=G}8(ymv*k1L3eLPyf z>lhBD#-^OrMv$t4ExA>^s%olHE;XVR#&Sp$CN$paWCZi)&AdXFcj?Rn#m5Jmbbp|h zI0_Pu%CSC5@DQTh3^0n|)iFbvL`5RZFlJO*SNg(0L&-s8m6HR3Dlx1r?cs*P^$%r6 zn-V;RGxM0)@Yr~hj9yPk5VgsSm{vENZPZw*M?Yi0Eu&Ylv_OyW%Dh%?d3$=*bekSl z3SBYvVC6~fB13>O*(7E)Dv&uz2C;cbY6DQ4)*{r+`{eZG58b=u-2VJA|e# z#5WchIYm@FR~kmmdS!}H!_*c(zu!P>XQ&K6G7enZakX|9*slZy?D5y_Hcnk=*Vr`W zVhFc@F`7W0DSrhJF1|2CEZ`*x2R?38Vpq&edrBhx2d*aJ4ZwNe*=~F6TJfN3IXaGp z4~p49*SxCZDgtPymnwKB+P9=EH}pTvd=|~_>YvEjI7f}_;YD6po^=e>qs=r0uLpxwSsFxm^H7=eBM>$H% zpq=Ie&r2cwYK@8t*$lIaG&l~d+QN<$;>^(U=oJ;CGi4)TJzRer9m)!sTF?lAgfIn? z+pK~*)W(t1Ly|k9a~!4TZHFB%^}3E^`laj-O`J@hPir@SbZKwrrJTNYRWpm)7<7=r z;5}0(o)>7GHk>q7b8|(4LexZN;pP!;7R7yM7699e>+EvsD(XNhdKZ;gO|B zDO*~BW_znCU0Uv(=rNx}shNc~36vbrKbD;t8VMd|%EJbvl?$vFRrAQ=fN8d7a8v-(;cA;3o}?+%A)FH$o4+wGq-S1 zsax&CC@V-{eCM8(uEUO^esL>ub{M8BMQqh87Iq6($??lCj*_rxajo5!*SxnDShvgd zd1zcovq)E{S9MayV(l`Jm=L1C3vKnQ4Q-KkK+`{$!BBXIybA&+Jn(= ztTPUl5@zqqH-CXt&U9vk8VFsrc&);ko9cO!(66fSMu@0z4E@ygI7W^2gjy7t5I4%YFok35J(~r` z)>lvXGda<0J6L;{c;O6V(htFnB%-+=JfUg}(Kt#oyPmwARc6i*Q1_Ns)mPxGE6!KS zF>zb+z^$&*9gwlU_LI!6E!tHZI`2waq-cgq8cCPkv9wv~%Kqpp`WF}bq%YU;Rg_Pr zDP(W$%lt8)NYcv7I8(4tyoZSt2(Q1;i9oe>TgT1JlMHh*qg5BaK%e1_{-Uhf_6Q;2Wu+XN7EBen2qOQ*8o)^RPx`GUuHtpTau}N$pl}www@R zUjSetIp~Uzwq;cV;+#s_8}Sgz9v7X+$K;CX4x&$Luy3mgAwxuc{%JL*7X8I~>c+P6 zV9y}GcrgRnFJ^>P1gV&3^l{^Ucwi2AfB;S=PSiaw=kj%=tGeHf#E14sv_Z?B2Yj%~Q}rPKN{< z`%_iHo-=jO-Yco#{TRXtgs1tlp*3?A@h!NUqCVtPkvj;U0fRZ^2-qh!(b97PMD)Tb zelF19#L(3W;d4b?htJIT9YatOy0u%CG1KMC9Zjn~>T+3)TL#}mU?by`N$k(!YHh@T z!5u`>VrR9|$V6#f9Aw$>Jp@-?z$5QuH91zmmozUR+fyio>Dr<9GB+mGuhv*1?&?0| 
zWc-bc)3EV1OEMz65sjW(RNc&xnjTsW;L zHQ-eelYP|5AW~q2;y>hg&a~Rt51){4e@rre_Yq(K%AeDrQ>O_vf6BQ;u#wwQUjtsfNAmu`(w_Y3TMu!B&< zYGS`NK@KDfXS+BATDTC&F%&%O91GrAge+V+}^d{6B@hf4F$@2;N3A`P!S#H2$_~G98ys z>!(HCy68_C^GRj(4P6+fxb;W$3+*tH?aM(|l{2Rn+ajrI@WpBvdM=dghgj&gC(4MI zdD6ni7f()=R*$6-TkFNDdrh2c^8nvgJv}T3!z~iXFfO%FjViX-1%hfTQ?tTE&xY)p z_w1jNC;CJMkA!QNvmF6LWX#UC!b4Cjni)G*6l(`p%8V;aXLmLmHu}5PKdg8sKmPK( z5@@U(Is8Shu4p~5gzkoBXUD*ydUX%!t{ggImV_nHD(4HB&M3Og)CZ{p(GsLXksf6` zq5Mf=CJWK(W4T4!%_#nkw_#HXyLF{o-K2-JG*dd3@|Ur4-|q%CZCp{S@U#|pt}FC& zHR^FyW)t|KsbA;w&UeRhJ9WEyl+d^g`+LylhNc3atNkyC-Q88ZBC7qPKD)5pCW%P` zQbjv-f7kRFoR!;+6pc64u(|enwPluW&LjM2Ur*OsvY1uiI4n@wpJP$XcBjRw@COF6 ziqXCs%hoEUzfLrhXS(#G>=Qx)ryq1sA{Np!l6vP))>qgE_bOKl1fM8fO5i^nLQPjK zyXfhTlK24A{Rl4a!ET+x;PodcwoYblZy!@ZgosYXhfDz)hZQwaPAzuwd-kH}>Wq#|5g>{&wn(}}cXdccJk zf87(;1w~s7YzkFG%^EMYg}9MX4$39=XRQZzLd_4RB>7q4yr=gM_T0Lpd*Y^()Ky7z zw}I61kZHIzA%o=RNJ6I!euQ~dQ^<0o2JrWMVJW+Pv`Q8=-Wpui7e53VyPq*EZUD1O z(JVu1q6;6P1l8Ow$Cq-q2>qRiXqt2n2)7J4Ip&O+L#2bQES>TmM>$yMm1GwQ0vqmE z#0k1xPkbCNiH$@)ObUt1Loe45zQf+zo14ejcE1#y-M*W~5*yqjg+Z}ATvCOs{XcgO zqI?(%qGb0XWrc7#0B)g?*iRuDrOL4*-Fu9a@yR)We-tPAY6R-ogP(9;i!YvnaoJ4f z;XBjF@#2%P$Nqx)%2;nONx22KdJei4>E`;;YxV}E%nR^^fG5nPk0eCdW;|rpN_@P3exwWPR5iC+bKp?VG7Dez7}HkF{rv&I8co*D zil{O3sH&^VNt^W|>SmX!1{0R7h^e|%RQBYR*>Phg6sI)1lQ_d7d3T3XeJ{{Qa0WW>2UH-TF!_P)}T zZXL)ilZ(QGy-5R0}u1=HfXOTgm#D&0{`PVYeBU`8$%^QtyyN+9l|rl5Q3+ zk{@9zK;D6OIXd=I%TFnvQdyMKsm#al=wPu{ zBvZvVlv**^P@huOMZSc|i4v|i&lDC&$v6uLEM+S`tlf!VvbErgCWVr+g0fE-55yHY zz(e?OTA0dCW1VrLcOkGAs`N_2dKxBKKbIwNnJh|J zS!D9^nF4*YMgBSiR`gFbs+KemK!<%WzRF`!8JZH#lKC5u!D!PD+*ssddY&FL9tb3~ z8MTE2K%o3gcA*H6W4w!+FIjL3y(FzyI_kku`H*mHdpzP~p1D&7+MXes0xqdDlKz%2 zO-pI#_|*H^vTPOr*tK`j8@VClq+%tf#e7O~*{Ayv<%!j)4xih)mUR_{bUmVC`XlMx zvM8pS(BJM%AsbZ%v;kb*t9M;%QsX&zgS^JP~ngDasfw;yWOc#!%iqCVlVZ+o|HszX; zW9@SWQcr~{Z^F+#Yj<|12fUvrigP9lI?9ChvB+ac>F;$WJfcmdRfZWTJE@F5^z? 
zvub&=SAHvdm1T;&OmkAou0|kyHa}@%S01o~Knda0L^=?Z%B)M4r(_590GR~4{jPq~ zph=ZC?plv4`B`PHD!sdzkiS5qv|m#L(EA7I8{}qQO!Syc(gSet@*ekt-0jy34!L8-@QX+-!sJKZ;h^Wjq=CY$Gz12KKhk7b~P;nemo70iBk>F`{0aqBTfFD-XRIQOQ7%VgCT zCZt0t(GDF4{8Tfqs-f?tnm#}Q{#hG|)BEy#$K*VT)pz6p8SRyvlbmCZ$?ESmSldN5 zxTi`mP}5M|VSM+BudXyo3OO%?Y1VoJJCW1nKF_u6SsYSjDV$hF zl2@d`h_k8X==ol*pVXbjzovxgT?SxYpv=+7FBZ{zLf*TDWP>LPNu%F`puOw-gF;>+ zxkH1Bj-p*fyMihHoua%TXke8$kF5?yFh6h^ezEP` zF;b8jv&4{__@0Xl**KpR!g)?Bj(-H#woR-a78eF$m|xyf;xFUsA&GweB=3I8EpXR-V{*n&Zz1cH=~s{sVXQ|-wf>EFji3>v;Ys-!K^^Z$-#@~ zr;pB1WDKG${c2~k3BzR!-O&K4{!+IQx$H~_x?C!U(*)o`yY*K`jK~Bh%QW8sQvWKL z?sRZH^zikZ{PJupdzMgI(YF49n0yG3M((mZNnB@hi@=P!19&H2_ka%Q0!dDZoe}$% z4O`o1d6&%@bC;#pU(Xtv9T=;n>sZl=|`_-sQKC~GUW6* z2#R(Hjq4kJC`EsG;iRAayqKy0KG1=%Jh; zGhk#yI1Xa?tv5#c~#0s3LbLBjA`BP+7b52cdwepb(^p1353%7CgDZ%sC-3x~sx% zx7xj+!Nz=3NaqEolIA2GKfyzBItp>FrG60C6jt6~gV}3~?-5zF>$-)SJWx_L`FCD$ zE6I0>$#1orRJ66IQ7p1h8NV~ua}>dzIgVgCUxholV4;Pp_Nk z>qc6=R<8UISbhk}@+smoLIM1SO7tQxYf$frb*8N?cpVFyUQ#)T-_yOR*DD0m^gIhD^^=7cYf6uwDK5&-mJ*1 z#|j}i?2Z~&1q^;UN{g7G#%!h)$|TV-&>y3g|R3z*G}nB z^4$CPCPz3xSy>YJ3%1VzV3#TaHA9y#32Dw96y7Wo2aIK~!Yk661BRrh0>S0|6Sv@C zXtT7M(~zTI8WqzQ8NF9dqRS7@KN8@R*{|)YOh)Y?9~iB8-0AXqi~x{`-al_&d|YSc z|M@o2+$qJcd)kHMGC4Xu>w+L7;j}l}FPJL*0Y@`#V|W`}ouwnS538PtG!hi5QRGxe zv1fd$i3q{Pgaxq3HU*rXc^+VF9R=L@*7~!uC#$XAJXT46lQPhd) z%DVp*7Nwo9KA8)Z&3DVuncvb_p6xszGVEB^T8?>;+bIW;GKtF{7Zit)awKWgq zm04(&rK4p4X!j3T+~3H&4R0le_w|ZjuQh)5iO*o?cl6!{WbtiOIDp>UE*o_N|6c#p z$2Z0JAgL34s;4#mub44(2CH`Gf}$7_csXdS)tW>=o0RW^_J3L z_d;3ODDPf|cj3llRbA`n?_%%=k)f-J&X)j);{Q0(?04|QI*I4sfmn6W|E*X$@oFc; zvurEw@fg>paOzAKzCz;e^y|y^aCtXrJ0GdG0ExIJuqS}S?DaEOPJ&gP6Tgt|1E_Ut zvq)J_-zivWTYA?9jWzxtE zt8K1*SbF*M6p**^M7qp(fM~J-WP_a`Qr8}PHe7koqUflF zvTI)3*~_*TWrwa`a@hd9yFdaJIsLj%+(27UWB7B7tG@K(2!~J87p8!geX`c%M(HnF zZ}$Ry+sJUo8px32GK;i5-uQLpWXMp4++bHr z_9{!~-la~luIJyKVzDiMdmQnP3K%{?t=m~ZgZ|Rt?K+`{{zfXU4ljSbornG8-{AO3 zcF6Ht)4N7?$l>po2i8E2|DSa@KM(-1JOqUQKbzzhwU2?vY8sxu03;8Xq*2WM)zTd5T>NjD@t 
zK4Eu3{(dL=4)2`zj)+{yM7NLcWYnw`tV<13*Opp`Ac{4q&d zEh;bxIhEJK5Lpm^A)E#%sboUa`f+t0n)d9!qG^o+C3wBYyJeGiznm6ozkxrqV4GiP zRdD3+xgA-Z?~Iwg^zWsL{7(jBUcn@kNh#N_f!)tH^2`uhn0BdqYJV;7;}Ju{?e3 z6c^T7VY*8ew8Kl=Kol7JPV_%6m6(UPUf9Ssj7G{RwV`|R_`1A1B!*)AQ2J`UO5~vFR48Od!07z`~ z(w^@u5GnhXTF}Zw0x0V{ulba@+$H*$2@_Oay92294ZJiX@n7Qb4=9rM8KLs0g2_Bg zdnEP$MP3t0UG98mbN3aV6W^BkFNxGIK%*UmF*GVbdozGDubXx3y!39wG@YR(odoMY z0|<5OrT+_J)nx}AyH`fYao)!tLSk<`?aHLzU|Je-R z2^1kW9(j52F#nPnUfc(GbDkNVsmR)IaZ}bpIi@*9r|?Blj9TB%mikDjt0+{n#W?}s z;y$I`sS)NO%m==z=w*$Zt?tiUB!+=l3Q;=vUlHbsAeof4_YCI?>r(KwbV;&8BU<-0jG}SmAy<#`Gk7Rw%o7#&BC!XKuGOLZu0OLx&d+ zK}B_YPm}`NW_AD~?Z8*AM^ON|ae9m9ci-Jqx#`S@$V=_!e4#piafr@c%KinbQ-U|| zE;zt*q!a%QGYmVaDE_iXKa5IG``nD(u@jaL<=^CR{Tma`Ch2PnMcq z@e0W=eM9RlkbJN50Illenf+EZ1x)ad108;$zO7o9YX+R*F|$Ab|6$AhN)!CUmixn& z`v>0T{{UO=4?XiIs^K3#@UM>e4sS~#( zeWOW5)f82?@$sBg?#(>|4JQV&F6?(0d6aZ4O2HB6`er+foLl0OKY9Q0^si(2>hiIt zfhU@!2h0FO6|+*;vWzuWdejis!=w;vnbNqVxL)>bMj;W`*!fHH_4oK%{DcYFpEd)P zM3eM7{_>`edP;}OVvMHn{C8<|sqGqKf z82dL9?j@(0dL;|k?KOH%EC~g+97il~LgX0_qjVudnUB@eE%4nC6NOog--L>q>=j*? 
zee_}P_s9Fy`Q-)Ms-sFzmzfjFRB0QR#0eCx-{>32+v9Si#0M-%}b z(DPcL&0R*0l=$fo%5SeJuv`t0-b6JeJ^D9X#d97)dN0 zSn9CdE7qs zoFVr#jE)lP+c3eZ8Y7e^p{oj5*^;V-&Mh*+7G~qv4D3K16Dk)pSqc*mV2nkGamTNg z47%*p3zsokLPpCl03=Id;YTd8a?!K$DVCy-J`|$T*mV=2cu5IHQnS@goHO_)${$6o zZGrkIIYavz%DNX|YRdZ~+g9SeDhEaBC)4Xo|7L2+aS|qYM+g(D5!@j)0MG!LoYy0a z$IY)%z!)EkR^!Y8>?noX#42p%?$Jt}8yRlM_J&hkf5+-Oyn4y^8Wm%G9jZZizoO+}1ln4ou-447+6 zuRjb3$L^k~+t^){DP+V72brHhL!*N~Paf_n3i8U@f>W|@?ap7Ye>?OWoe1Ph*Y@J# z5etT;!u5@TK9#T|*J^lFR$8}xbjp2XNY{+7=e1eU?)l)JMrc5<3(Ajo9Prg5N^NNt zA|dczQu`+xLE|C1005+Uu$0RU#b}qaXjo4d`IE#}FW|ICC?)?Mh1`Y;n9}QV5@`+K zO!qASryKiGwsMsXT@z6zgd#@9bHmzr;Vc2wxlemmC-C{U9;P>}J|bcXFr!_OA!AR} zPPi-tjmD$LswPFEdV8sQ)(m4}Xf)dt)+2=*mj%;@YhSn>zuxv*9~9agOrZdXDzz4B zd);ig-PtIGsiJ&;*LOVj-wwmbzVxvk9uw_7?x86lp)1m&xN<;3va=GnGry6L*7|g@ z-VQSxBJbEw9`&wb#&<~H#T~V}P+QcmedsT_s<{CTJN>3P9)Pm#W(6?!_*rR|y1rbW z)eNOAc%;+=5qc_;3NUD?5MUSI`j86x(-d(#9(Lv&y1R;ZB(R)ADy;s@CJ-qrMz{;@ zUn_M0SDxx&M6Q2?1K$Zcoj%FqOtYlc2wzAmWRFUrQ;%Oq&xIk-*oiQkz`Lfdxys}W zG!4lbPm9R!l_f6wu$Da|7oy30D3shng%9Rnr|0^Otpm*C#Za97(p`ajWURI>v9R7B z;6Yo@YK;Mh@Bk8e34q(HkW0KCXZLu{*(<8Q)c!oT|*s%%C|Uxix#Gk&C+@x)RpO_z!^ypYR=p-h^-9k27weX;_U#Q@uBHO zHNYVg55wGso+oxCTenXmk!osn$~!)eV_)G5&sD!-8*+vWM^0cE+2fi%L+lK#g+r)h zqql+DAI;MWgGa{D6i$aWB8}B#8IlQaF2pH+yn@qfrgJ8u#VNFk$#9xy$e4}Hz{T?A zj|4`UYW3kCyOu5c3&HzF0HdWdoB>yzS7)YhKRx7eWUK&?qwu`B*^)E{b|MbFh)D`k z(L)>hb9|V`%aFqB*8)E>6VMgc zp*^Z*J=#k7(!22&gWAqY>I)Nc7Ow;buAZIN z9i`YMRW@E(fIMh-H;CSnzywgb9fOqpG?&`h@ekt;uy|>=pm+yTU(bJH2dz#O*RFao zsd6UAccpXx(V%H0iBiQf;a;r=&ftv7a*%eZE@9x0A$tq=$xUu7+wcfo^p$kt=CRGr ze3%(_8#&IvAI99&U9Ck(!9d)^*Q zl_llHY}g+>yw{KmRzYxcffGT8USy_fq*leyl69oY@N=crY-*SaBLtW$?;P$ic6X7~ zJBI5--*hvsEfsRZu-XRRx`ENx-A}!3UJe`r8$wlEXsGVS6tejMGB>Q5S>f!jU>GUA zMB4xd*UtK0ma#v}%1cR&n{1ETA8^m9+u@lnWNKzS0>R110F6GX_(psJ`b{}*X}^@} zyal})>Xu)Y37>eq7K;m}`NXOcE4T)9WEH@Lf8E;*e_g zY~B@y@Owu8=gKcpcY7Xp&Zj2;EpysDllE#(R~Nt?GIz20iS;2A86PR;KSpAjd!Rdg3D@O%Hv*2f^uYljw5t70q|nmy-8RF@7I}N4cY3&DcXy) 
zsL}lL(d@IaU^W3bT{*ZRo7ywa_%Ke{c7$}#YcvZ5QMrle2V>33><1SaIQlG=&}Cl8 zzD478@^VH^Ox@}2fKSwfFTjG_n6Fb=@9i8J2M+hNB!UnCEs7&iYYdM2P8KOgs&K<( zpY&ywegTeP3&d;#@NnS&82L$JpPq~AGuPn9B860H0~$g{$M`Q0=Iz&jFf3&o1QvV5 z`*TzO(PN&R9}s4go}n?oaQQW$(P+}b;Y*~6t9#;KNffiXfh=XF6} z9GUIjn%8=Htnwce0mT7L1k4+VQ!BeXSgyAp0@0-^jz^LN*a(X@do_QgB;S&h&BnWD z+S^v7YNDCcZGItk-onOqRmTl!t7)^f*}+qr9SE-@sj7gF^g!8QRQXTRS&Wg?kXaro zxla_aI7!^1y{nb?DF91+6TBX}Fp^z=el)!|Yqe2MAv+AQsKd;u_V;5hl4*+f4=7aO znFzWD$QLm@7IpBn*?*t=@V+Yt;8zF(^RYVkHCq0=d;&OY7ugE{-`jY)5T@)=M!!`A z2yt673LY3t)+3N5mE{6urXDCp_9(8nDNG1PqSe@{G)?`%o{qMN5koFg>`-y3C0M67EW|s{oS`yVf4tKGcN=zNy6f|Bsa-y<6*^LLwkJ|qmSU$h&&;j!E z;TdUd8zrQPUDjm&L=ZpFNH^`$`c0R)VMUfClwQe%@I4fMjpsXK);fk?nG7OLZ%Yu^ zb+Vx3efJ17@XW$^Y%*hr0t~`sOAh)KZ^0;#$XElVtVX#&9372x4OZGMQZ&!N*1BSl1X&_FD4 zk~DXpwpvozY09%+d2p6Zfi~egVt69H9*!h;#)Ee4_lqgkFUg5!qJsmxi$z`B^|G}%8ey;n-==Sqfv%PLu z`zXbc2U!@b$iC|05bID`i^3F1L*PJwDxfaVK2kyRL-Wcn?9sF4^?eX`HWXzZg>2)Hr#dK0n42t(q#p+LUVeaKkh98zc+MJr5584sOCNBB=|_vpYjx0&{0kySuP&#=b_J1ZY@aujwF& zcMSlv``hpTCd{qdd>7dH6jgbvGKdEdE}hOk{N-9LKcnrw1Kxi3*W=zmX>Vn{tmw0w zyQkwR=N1n~>$VH;Bm0WC-sPPpnU`G`@|%lBLZDTVfK$E~5>EpWJyM~_q&oDERskmY#{hW_{*SK$eB}SFZhb}Otp~$c V$^DhUw~LpZ)igMhf7v(gc$-_#D&-@4v-}( z$fk@40Rj$CR%8YeLJ(zUv64VSlKUKR)%Npg@BQcge)sENwVa&iJm>wq$Ll@MB^#?< z%8E-AVKA8Tp4~eRz+f`TFc{){`MKaPtI}>2nN%I?b)%_F3^=l zS1EgR%tcS)z4Q3F#TQ7w`w)u`=G&;1MvT#ZdftBKFy$s=pT3i1=|Nj6Y|bUb9EW+z zzZB|7w#Msj-SJ}4LTWT!H|*Rk%`)N9^5Q}_ne-EyDPy~St(s!vjLDtG&+cv>s26M@ z3ra`>7P*eOcQ%v#T?ArD7kdar_MTNuQpZ;-lq0`dvB{#A> z6~4S86AXvn-1)`N!2;iAhQqwF%2w$8c_DB`6*ctP%I}{Qk*_G9TwUvZ=F>Aj>;isM z`+wTO0zmVwWBe}h;`4;n}i*I;?VV=nVdt!IyMUnGh!5Y$=c+V1=0&4n;VGyACPRIMXt zj9O}8P0*kgdfG}$#~aOa6Y!ELF+L#V@$ze%?psWcJYyHre7$#R2fefHFP$E%qwUd4 zXuW36d4i?;Q`l!2CQDAlY*f$gT_(TgClst`rJ>OZRB~J;%Vy~Hb`>AF?@m=>2~UnU z#PL&9a3_ec;k|Cc1GFcnkD>{_E>fzY(~`)&6lcHv=5Kyg6XLlWk7#Z@d8e~1A+Yz7 znB*41E@DT>X`d)CwQ6&wtITrETIUEY{9y5Qi#j%e(S6NqIQ?#Xq87?u8d$y9W8zKg zO82;IV@K4nn2lYM(|t176}sWgnlkchyd#?W`z0c7(&nK;mxi_iWz3GraP(%q=fJV^Lwt30o7zhy99ocfM(LM#_E{ 
zm%{vERyw>C>brGZg8E{`uedJkgFE%^Bu>?P6my;n9;YP=Z^AOv;+Ro>=AwuMi$mFo zXG~?xfY4d$qTCNEUG?)T$h_|9_@?y_!;A6nrL9@^I5);p)JZO;kR&|zqj=vq!PI`b z${TQwzex^S`H5k^d=*w{3~6XTCN9Xi>RQ@;R#P*rQdZUJ3>@D2`;EBku!yA!Yufsc zZ?G1GrZ+t^@z60r4&3vjc6ubyNy@pjN-e+obv`+Se&tt2xM-RS8~4NRtzMD z6#JP^yxuO7^G!TcyUJSO=Fa%MXYTMl2GkuI*A@S-oQ>iKmQQl zANF2z?u)5(7NwSGFS1np{hFn45RyEe=~_3KiD-a_JckxkM<|dF?pGDx@m@z z$;Hrfeg-E}*y5aazh5c&xZFVv*T6ScNmo2$9Edet^s#RM;mD;#Z-h*-!c}iSW6Jhd z9;sx(Y1w7(H~WhM>m+`ID9W3z(p);Q&0_p6nU(ZrLvU7Q4kBGkfTR}~o1W}=T^e$+K+J&-+J+CGan#G#$B2t@iZ$1|q;!YfZkoc8fOj%2vkL?0wb1}CSe zMcx_{SWElmA4vL>yQ3(zlT_uKgR5^uq$kMkr&~n#-(%h;t$EiZccqgID4KC4YkSXQ z1Y)Jj*6_+d7CF^lIJPb~L2JFglrFwd#@X6OPcrA06-Y@OhUG|vN}-z7ro#s*l?`gM zve5EFUZw?kFzN(doy^l@^(!X#E8+TsH&h+ryIcv%wxFKn#bq9J1-tIE#{}#Rubhv# z(wVoWzbN$>r(DpD9etrePzf73pDnU~mb!L$@@))PXGcYtIM`(Ahqs8W%;9etx2QRj z1M|vZBop2$$5i{1BZm`L6`kp;#lBn`Rrt^n4m6>CcK?xEcY#8>tKP+aFRLHSD?(-k zVA#0AC%nH{{knZXrPnKRuz#`oH-HfDKhzccX_48RBaQ)ed+4}q0P!zZ_9)3GuQ2en z_|m{vt_P=YP;idu@|PWx4*oy))Oea2QRXB{R!}Zdz_9#9OZC|u=mR=cM1V1n|s?ou+v9Ow8j{$ zwO;=T?Y&0z=kf`?Rf_twJ||bkrmA}5 z^d7k8gD$JQt-U%6ULzNjPj)Cu)J=Df*2YDkoWUBt-`-7mN*PFzhC79;4fS+%D8d_G zi_&@}Kf%;RWgCLy!<7&C>UZdl#>fvQU33jSDJ#9K`5^FE14`G>9j|VBwch3pvjDcb z8J?eyxOd0x>{;*wkv`n6uppW%d-D1j&8oSV@HbYJLu&qbVui!z@a z{^f$ATav8JZ}TrkXRC(K6U#Kc&Lz!0y7b7zMrh{_|FClpxZQ&}M~=$ctRoE$f>(nG zdBQ`Jq12(+>)<;J3d&vU0Sgq+wOm>gk>GR*@W-XR~F3Za8 zz3|)O)uFgWs2|OAHx_S$t)Ecav7IEDp&5&LxrrtNh=%@mL{mq-n7*G1QPoW$`F?^= zsFC~mn>6{2y8*KZ_3mrvpsPr|b62?@G*-Q+)Zz9*dAYIOq;O66<+G@hCnI6M4C-CH z5GJ0XSqDWum7oBKR{ZaXrXPNJ>c?GR}mJRygs>0X2N=hk!A5Y z{0rp+Wq%ykRn^ENUDn)rJbi?B5T-s94)NH_*Vd}aV&M#z=I83~#Hc|;ll@mjbAYLj zU%8e8sEYp~{4%z9v>f8m<^PUH)e+ZIl602{c{4OS;^ug@dj&)^#eYS#EZFLP4J{Wy zRY*SAC*OuZJi7DW@uW6DWEFEq}0x8 z0mP%x|Bgo;veVYKC}w8)!M6s0XZWZ3(yg9a)Z z?!Vb;NbhWky+F{4HlAb|(rJO9Rj~LQA;dG$l!}IaxZ+BF36SpxoI$Q!VfS^aBNT z#Jv=`JBWK_hR?vk={P>~XV>Z$bLCA-{Lfgs4<$0d^`GB#IC(X|T@=?`0-r=1b$ zFX9dQatMY1@ zP20IT;3raaP%YR`iR?n%k-u@AzM4?>f>K;V}kV08E^gYT7iyGVG$*XvVHb+FjaM;dn=3n0K2&dxI-N^ 
zdpO`GDmtOE5L!|HdvKRV&%3r{=vX8Rz)wiAY0iLO(7JHM-}y<1bl{@Rh1&r%-;N&n zcqmSoswU1H z=Ct;bl^<#WzUt3^V~oheF_!>(+qF%-7y0!bmIZVy&Bv!9&Sd=B=(&ebRw1E}Q~y1x zL#yba<}>`zngEX#5@Krc*`sCBOa%;HM`|{91+$z3fi=XK;;zD4@8)R&zBmnj-v6DY_p&L09**ciif=H)J5qjZN9P#0| zGe*4iz}csqh7>%O?DKf1Jro)|lp&mtOMEDw%*w@*lw@rF@^%Irn`qaQ8R0C#1qa!Hd zcx*21LZnYcuNb?`Njnj5it+VG!eK*tBIKyUZJD&Y9*Lb+o8J7&pB_Ns8>6l>0vI@U z^IjvL+b-CH$VVc-iKbnID~6>f`hBn-_fGGm6?(SZ6P#)=p-hYi;1{i@?9hF5asi;p z*43eN0AOG)72ohM3a>392&X~9W-7XRN)g_=$$uaT4KpD%ZE^#JD`}i*Kc(b;(VTK(0b z{%*O>CoIa&~0gd?;`s6yw+_yAHIwUjC7iP0zHL9HHDk zl5umhGhR_te}8&vmkbaN5v%+Wj#m}6T&8-3(j?sN3>=$3-Oc2CsvswNCR}^AI2Ir0 z`m+~Mk{n1WB5wYfW3;gb%%iCvHNp;gvo8&Pf9b6Yep#5ePg_f3yEg5Q9|X&)^MvkJ zs0IQ5NrJ2m{Yi^8vK%iz<&0l6zNqX)y~%|``*}C5P<^**6c&9tbGI3fX#UX)bh=wK z$9F06W;`Yu{CZy7Lt?Ej-hW&=UGp&h)KlZ~5z+ipVzrG)<<7Fw?0~M{R(^{Pe&Ltc z;*76fyT0^SX?ym9bC176TZS(Y+vSFAY>jnv-8}wXIz33K`$&IWtyj$GeHojh&s{28 z0cr(UTlSmaeMU{lqvu?D3e~4&ZIssVzo|G`aq6YY>xMLBtqxnoZ2J4X%F{_arz<4jeF7Wf_|w zJ|f*&#CG6rOZOvN5XyPk7`C)r>6T&H$WtuUGp$Hjb7P@V=o=*(;4SZ*(%}rr+tk-X zgQ+pTx1G=IkqS8*+JI0K=u3be<=0PpQ*wL=U1#>1kG)(jqO8H@th%qs1PZ>o)|~rp z#%=NHULV7j6?&X3REm3Yjs)}C`Y@5IF$HI|-g=|}+^pr0!-VQ<11{2g%?Sh9@sb}i zJ_+te2h!lha9Abbo6CEVQ82He{TjxpFt2ro_DZRZZm^<>9HoSxVMR8NN6Q+GldIo) zwJ!rIxN93wHV&{@%66e1*%cP+jqyjc7(&kTbzmUfV6l#K;C8@81x_7>F>DK6F#zR8 zpQ23WaCPn+W7Ru*A)jAk#tHpI$&gGH6y)1IP-Ik^!Dk;*fi3~Bxtb7_B=aA0D7lljc4XDFlx>fDK_bcXJz z04)6{zC+_+(BGFq=KOO#3mDeo7hWlJjS8gfqd0v2lrIgLS%fEmC4CONVX<5wUYQZF z)w6Wu!Q~+r`uopW&kFTIR5?$ke<8SoANd)9XLau!P;c3{=rI)7c<(!Z5Za}g6i^4% z*A7nVZ=YiKU!wwlJ&^A|ceVo%tGnfn>#n^GInzcn=q0^uiG1>V3DyaR{=WMkob1m9 z=<{O-bestv%+S!voeD)H)EEAT15r@@l({3R*^FAQofA<4u{&<*e-Q(<_xgqyt3L@I z6Y_G{A9_j=2tfR&WB*5v|NKDqgjarwR3!iy8L?Qt9#n7DvX%Y|?x3KkiTG-dX&2%bcY)_y8pzxOqto(k;u1C z5mNEme+0Wg_xQ*{p!M&M{XYkC{zfl-9&EzvCHLao=OX-|fSqXDf8pK#7G(Md?KgW; zGLKv!WwP`?9LVoqRn{l${Z(bn23h|fRaqSYbSfVt0Tx%rHsUTZ+c|C8p_9xUp6X<$Fp~osX(eJ!WUn}yTa96(|8$!PheVSxynkdajaZ5*`s@}yaY1ui zbt7az<5Pe+>^^sU2B1LL0jlWmb#zLoyaCy1tMpIs_%naIj6C$u;c{uppM0DAv71nc 
zTIo*~`(v3B3ic^EfCcXTlee5ko28YDgQa#|Zkypn#H8P32 zQGFkB0BBQuv~LxJ9v1EX+{ApU^7$+o0s`Dnnbe1D21p6K^mv$W4AB4-wJ2wOFoMAB z6cIj`aO4E6YSwSv4`3eBArTaf+xbR(j+Q>v=g5Cbdw{_IOqmc!C#b+cjm*99%b^ZS zm@=(5Cr|*If#-Oy0Ny#IcJcK#?3y_K*gr8%QnJJ3!H~=H@SzG!L06zU6lhg{ee~*gYEHXR}dEj8N zKj$xWg2Q`+R9dU9PctQi8EdS|D`l$m7*xFm z0cRpcz$0}-1sG6a{KwvrkN;qS(uIfSpIS#gT&P~H8@@z8_tjsTaX>I5_{$a*&>NTq za9sXa+`TUwUO@Ehq1@NyzY+&DPkrSang36Um06R1#fMseuM+(Kj(V&7Msd2QK*bv` z9Kf%ObS%!>OT|cWG)L{p;is-!B6H1H2a&HO(_~BINXP!)fs$Uk4|Ei+W%g%$)cTi| zvYhhoTL^1N79GT{@e)lM)F)NtDvb%HT%omRqvb-xhLF;40j-#E#h0IGf9k8uHUVzB zCJ0D~v#r8W7QksQk1JjN;d%rD^8A-6e!cvaETF|0u!E9vvNn!2itVUQ!U)2Kbc96L za<3TdIHv5OMS!@9ofRN)5b-iGgHo}aikKr|IOaHak|aZ*gHVcVNMtnQL+jM63*VB< z(&@o99#l7J<1zM-h}&X0GK!`$4FdzDqqoa?UhH}I&4%OxhOMDN5)I9KqDFpeXGW{G`e z#7nW$Fx0}=|IGfuXp%(rvzm;Q6lo+$jBz$?;vrXjZ*ExuPEsczAWMgP`L8KDK1hB6 zsaY!M_7)7+IG?WaG|o-3KDrLxN-mpJX47?oymgjDOS%HMw`sZJv4?CQ(x^UeWn1K4 zqu?PWq3)1<&2*@gYD>oX=-Bz=ODWPc)3S1vKXNANZWuTGWE0=d$UIZZ1;yOMHTOh~ zrj2D`6-*6OgWkKhP;2u&(7=4!)ph?z*yz3tu~!RKK722r&@HfUFYT}#-jRvOye*}r z(vHShbazh_)CdGE3D`sV_lp9YdFHq~JR*VK#1S`3shqb3QcInu2jpx95`upgJ*N1a zWP}U`6dC53^-Cp95}|eQGi%=M<^=3;z2$^N+7)$>AeBf5pQU%|B#8x6IbfDS()Ae7 zIPZQMsGt`?K_(v?yCW{iuUQcuCtg(6gl3l#+8e@@=&>ZB68o4@uq(6s4~nb}7bzH^ z1#!i%>lGf~XcXb8=rBnv`-e2wwWBaEK6+#)rij^n$W&x-Ryx$|b`r@~BV*`jekpAD z-W6+aE?#7BtoOih6VBFr>bBL%G?H)&ymg0J>$3kus4EJ|{lQAE0l(3WlR3htw`s=G zA!Di@XS(jL2P)bXKPo5jGw&6V2;u^aMPB*%a122o&zrEk7k5F;hRzEhhYNcx+`Jf) z1V+2$U5usi_`LKcG(RVpT*MY7ijAm*Fj@=EgfyCj);JJt*)P+&gF3q#P@3#jW#_v4 z)3KdFG{G{=#TzK-!S2A`3wN&%6$Ic@R1!#7RNL=o4aL7&tu!8 z5DCd>qRLjo#(%4wAao_X^<8l5PtcoQ1#7Q6DDap# znI5a}*!b%{I_9aStS$Z7%U$HCCE}gsoTH>Lp0aq-w48s;Smk@Si_eZ|Z?hCFwxM^? 
z(L-Sq9Xdn!ak?H}il3;9PE-bS5iZ}o{#dj>Fx5RHq+@b$$|jYOPQkIB#piNsNd3x4 zlX28oq%$e%y3vkT-oOCoodO24vZBF`+R_6J)m6BKN z^jV6KxpfIp@$NTNpN~Tp=0QaaXd|dPuuVZV9Ypgp9Bl)3d^kaynOcS&UoU?p3lKux ze`JS`O*gN8}%l~7@lJ^=mtjU<%8PascwY6gM zg%LlX+^a2v|3(G>#DaN{?1jNG(KM*!PAY;5W}$LJ2M9EVNyP%Lw1VOq!qdTnvOOwU z+(j-zPCrAQByY(I(Pe0S~P$ z-ANoN)WN7sKf6JZ7FswI*)pcz@Dms!7T?^_wZpf3^c7R;L3Uss!bZ|M_WuUxF%C!(2qQwNgSarMp-yy~~ z%9sSO-uaqy4~GI@}%qIs?K)TE=mFD4#K~ z-Dwy$T`72`xmWTAgV+k#zMl;V>?Sz6=Vq%sMou;p^`eQ}j?yl5@ch+9^o zk~(`cF$~@sy3NuGD5Mvl3TV+lU)x88B(q{!-kmQ{A2+bZTxh#*7V`OELeNl2Ilzjh z6{FGYdSmiQP&?iHd`@~Nrb@HCEhogA&LX@nA1(|i)|viIgTOHeae?_ieMM@Q)J2&# z+`C>Xbitf7k_;KsL`I>YVtSdD*8+$MtN;^iGRp0~RHQ@5c(D0+DQ z-6~g-();;r!wL}6VHyM`>=w*R2i4sKTungeF4069(gFg5OopIBPy|f~!_f#hv@|Q- z@WB$?8j*}$fEWUu(y?-Or3bhno-{pCZ=Adn(@C5X)=>k($OY62xiB=5cY$#-H~5*k z=oE-E2Pdg;F2W+-D6L__hI5D41MZO!WDOHlr4NdO4bnXZCT06Z~LK1S5@SSh_! zHzUNdwNS=}K9p9{y-jemrrdYMFG7)##7Ad~Kipnphd()LAxx2zi5PG%J9eQ$l7$OEaf87 zNo8+u-DNma5-tdDmGgZI=A?7p^wMbp?XZyu$@^bwZ*zS+Ug-ji@-AJI@LYO+FJ z9v0Z%>cm7ayVJsiLurI)P18p}#SM(t{uU1V$!NypM+~Xvk>akIt$AU;W<&U)%SF}e zY5)j6omKvgpJMW@ZB5^SSDi#6JmGc$8$zF5P-b7y^!-+}CoMre)RDaM_4ikxzMRIo z%$ayr)bNSGtc29FpB5gAs4&bDkTOPc)}(aU8yX#n20ySXpmo4w^oMf+mvxt{X&K zw1!*gY0^$XT8R7G{)Tu$d`RJL!UC3+U$iUn?S4+6(0i-A?c_3e>o<#D`~c1As{G@d zJd&67W*21ZPEb(eVD)Kxb8S+iaJjDKr7EXp=4c)WEKC`8Wt3ZB=%xN_s&_71CvH=) ztDS&(j3^)mNZT#y%R;JKfmWxN_GD;MELuFj*~Yel`PZ+jDP)Ms8Qv<$=~FtP?2{bM z&g01QhtC?goN5=l1sKpLIfrV*dAvP^*)Xq#we7RGYaC3yX`{~cS;vzGZH0=Ms`%!o zPu76*{9fU~TpPLtC>8G%pI0FwQ*|?xZtqOHyVEeW!7hM0*dLGS=9sqlt2l&K8z(j) zo$^EIB8F6gZ9nM2!|RCeGK1uaYAFAz_CeG0S94B#kPm$;7O~V^Q&UG>bT^vM+nXW% z-9SJ5(E6X$W~P}%$anjUH##Pf`@iiT0&Xfx`m|}bG(a&p( zh=TRe2`j_nsxzLffaKh%5y;T9T&RQQ0;u46&o29Zgqa)>XOe`O*s)i7WSO5+Z)_*} zd8LQUZ`BIuYVs``tqv1{d8uVCr(dYkEY#cstWfN49sK=by1fKI>Eei|euA0Yv%X`I zM9lXAN4QCY_0ueeBEV>)Me-(*Mrxq`=f{s6u+JW5(yNuAdWeYiVQ#z6IcQBUxW^$$ zB@(V2Xxg#T&%YBD7GoX`I)mSYK9x_tV5{#qqdtjx?)ah*-@B zw8ftV!y^xc3DfYUE_g9djOk};7fSj~NhJx;e9EBL!t1oXFq3z~CD<9F&qtWgwE^vD 
zZcwLtG^m4?IOw606H$j?ZL-^tEmrVBOgDHdLwm#0>(sLS)%T;J10Fw#(S4XAUPq z4u<;)aMs&1J92UzfeUT@(MP_%1f_NQ`2z9dWDYqcN}D==Q%^Ymu{AD38WM2e1fRy< zsdpLb0FfL84IQi`Wq)=Nh3|JYEVXN8`jYbFV4&{3&A8SE^T)6ZG?P!mV2pxZ?LE#p z-D(NJ&8~`pkcu=FYJ0Am+x<)#n;&WiSIq3D2T{O~VDKg*M~e0EmB5$6`R19P8L5&Z z1cXjt=7xw7=~cHHBh_@*K7`cYz-ex|p!kYW#S`4fd7XO-o~1p5>bZeG5+|DXCXFpB zGW?ZV76fLxvI+=FIglP>cQ<3(>>Z0*tjvA-P3V5{h5+kH&e{c>8!lz$FzePeg6Pq7 zDnVHr+i{p36*d@k{E+*^cj;Q&i&wFKlXbZT`aJfz&PEoh!3?t^vC6ud`Ign7ZhfXV z2!va7gA@RVELsXEXwYjWCurtOw1952BKoDwPSbT1%uCI#{f(vW@Yd^@1JqgmUn_I# z?aZ)%?r%U75Be@FA`Uo#>)QMaw+1(OyS9Xlo8tY5_Q#Poye8VK*11$af25Auf0epz z*5q|6uBn0MI_1(q2M^x?3^+ik7TA)5)s2_b`7W^ilnYY3ZalBNvNZ?2pDNnQJQ1|UyK)1O?gs!Lra68 zizjDv2TLCCJqdLJjxfuw&qj$}E^zXNX2n@0`DcP=1y=UHaMXT~ZD}aAn$B6IhO!v= z)$N3%^C_w=;=Y0PmL@ry8XZ479WyaqLt3FFj&t+XA-=m!g@Ujlyv^Lfn(WSTH1!~N zIvY#TsOYi`9zB??RCFAlrIcYv;f**_ag+yx*f3C_nXmQBvuxKY8Dt zwOY~3!lO;>E?+)+x7xPMEkcd_z*9KjN$As zPDYr;M>&K-SuHosmw12~xC=~5^Ibq-&Ks+4#h$kd7TdgW(fc8o*riYftIC(R9oh#g zdTX5QI(BZkh|tWQTvTSbTX@OV=uS|33u|CYbTuomz`QZvW*)qC+rzSJvxL2Kv+_eH z--BD|F@L^O$sayb{ie4=yZ7vWGVE#X9plrv{&8Q{c0htLz^G$CB-Se%0y@xzaRt2vT)G>nudjCByyFshAs1UTeI|Ky%&+ zGg1qnZU>Bay4kAf=1qfCL0xYD(UP|h=Otr9>*&rJT^50%HM-sb z&(eH>pW;M)`e0_+a1@C2saK-UrGSBms133}4vz!prSRGB@>k8nhr7cBjhV=9l6|hI z;4N*ubYw^^K+<n* z0A>~&Bz?JxI%M5EA?pS#@~2f?!lwGz4{lg_ zeHbTRol|kF?pE>PXBXPVb|W9An0HXJio;7iS(r1j}0k(7**pEU<#0Bl|rOz;Z?Xpuaa?*5>qf<5^Oj2BZRw(1|6oHg8Y6 z3}zwHBYz6gf^_eSwf4Z$pP0OMy8_H~Z?_=yfN9TOpb6OxG7K~*Hh@en`_u*9NSKdW zaRdfid>Z;jK%+kIvhG8Wa`*)_tbVK)i$= znu?^Ca>E2fs5zOSX&GGo6mrX5#lwySN?E@NyB&@3iP8xR1PD}AF0srIFF@I6ri-8n z!x5_?q?P>uAT2ED9l~d->DpMkXCTt8@xqpD+V}p4e-_|fD*d1-!{vNxrc2A7Ii{jQTa8fvr7bMALv(5?EJG==)YjT%ZF@p-#`G?jeMu=N z8WPdf@*QrdbUBr)8`&YH!+*#%-Ph2#v!4)Kg@^I8_n z_h#ij%6$rL^kYQ;v{4t`v%vcks4Pull2^|9y&k*hV3N3!Swvw$vggwAdLADMZymAx z8BB7mE(RZsc;oU!f5l8dSm~RNgRo=xB2?wCw-u5XzV*<+N@$OmQD38eRDyx802!SP|92o&`E#OR<90{h~bsJGGN>`!0jHJh8#g^#-QSv z=}$mytBsZXkO#K%LT0WG%3m{$L+AxS)w6%GjZ5@SOmsYx@EX^vmUIF{AbPN*CSOO= 
z;zuH%V4~0<&^!l>5$N`~1oH|s@q00&3pPO!KEC6X0xe1=5cA(%I?BKa)W6Dv6opMh zm5fcn_J?hmj*uT+Y$50>fS@+?J%d4RzywHgkvx*DmXwEma2Vb2V8p@95bWK(IAIWD&?4@Dx(Fxw@K4|+3^DfN#c`4UOO5s5Yl{!*;-(vzpx_$Q;?Hx)D*H= zS-F7H-D=54bBCnj9xfgW68qA<==>wcq(YnknI{zySS60qQN6^D=0@{Ad$zS?P>nsz zkhT;v;MN!^kC5w&!MCgevDtT#(wy=s7e|9htm`CBt`PLJuj4XxaIAdnOArxAG{QZ` zvxn_7sXpLhy&V4Haj4miO z2Tj$Ppw&8Iqneo(Zla0(`(DGDfgJQb512Vb-~MW%suze-vRMV>#2_)`RLJd3yLaX% zbiwfk>~1=ID(vEEh`5X3>2?0|?%to%g}~pX<9e`QjKNAryhmF$&8Uv(VkmVID1y zEN>1!eVAkhllf~1;NEouva}BH!&8#JKU}#soF1ooH1J* zkaZi}J=N(!Fm~f*@UuYj%Eg^``Dly}BM`*3C1lA-{3^ODdQwQt_7&V;4VQ0DsjePxr@kG7;%@P_O@x1^4derMwJ3vJVCrX)Nn;8iJy*_r(TwW@L9Q9 zRXMq|#oDhtlGB%qd$#Z8#1v8T=T}$zfV;_g7I%0-tojbGtzpJa_si6dxtUAktY;*8||K%1M*Y z|MI&4sUwLh?I&`$sj%AZFZIJ5p6DTDF027H(`R_j<{zGgn#s!)B*EJv&SD;dVu+wU zp9NnptIq^q3H@ILUt`4WA;QsIVT?xu6g?fh5+pyx@<93f1z;YEUTIY1NgD!O{DcK{8$lq z0qPsizWpaZ?`snFYx4Ph=fYd45r*Weu)S2x22Gl*s5+E2$)^-Sjkur~6*+w)sxTW(n9mowYy!V~ zv>%ASkNRng_tFpkBx5rKhHL(06rrSxO1=XeDyOaV`5#4A8LpRnhPm=1j+C+Cc@1Ct z_*})AnHXO$|IZ=IOk7VhhOj*`z?V3nB<*1EhV4TNX*tLOt*P^jw#G?oO7VWKNGxHf zAX(-jyj#hcAEm-HoxA|4z$7G_g9M=gQ`&_a-8A2+cWwltr8twFJcm+Lhn#+eZ60#u z>mWII@SQot>36bqIUjg=04F{ggHP4M*REsNNk?8tiZc05ZQY_!Ry`pjU$y*Bzj5U42y_4VFi0s_%>7;T%wR3yZ_N|u=9{7FL#@D~RpysIz zY<)6t{a%rHIw6hN$?PhsU<4#eSVdTWi^sjHMSEXm4zuEPO z(4M_--1UwvUB~3i**K(V3Qnx$Wd$2&GA-md=_wv`(>9t8)!kQ-}g+X=J zp3lq$58cAV%U#xn9p(uRN6pWvD3~sJoIN_n6Rf-cu8E}uGT@;4lC~rRG-sC%T99W#=%5nu7$EiiIw3*I&?q={OlMb&A0S6H@;!!mG6 zuH|ipm~3yxHq(IW`x}I1r&nP*>)@^T&;t)|JFfAySZX5Z?DV8QxI5Ii3$qZ#d#B;G z5FG1SP<%MtSa;&X3BwT<3kFw&Q|VxO?aR=)t(5q%eQEQ|&B~=!x3i6theyYgLd+9z z{jY|ezm0mN=jI%LX16^;Jg&UjN1^wJ0q^>(~bHc{yt7q$f zu6d}A;ir9fpiQs^-gYDD@gAVAU*`kOesuMmkHAHK!m39)*Jf%cjqZ~|^bmG9yqtL9 zh_tJ+ou7V2HI^LZ1p<>vAYHhq^p#_Gm*U-xKAldo)|-=Sx7u0i`~+(wGScrt3rB;6 zH4Tn_w(xFuYZ;qNN+>nh&n0ZhYMm9`+NeHe84g3^yJea(`j}?nt8wx{^OM{0WjQ?X zX~PY6{&xb=u-M#jet7~CliMA7JC_NoUHZ;Km=bpZIB;4Zq-XZ)`m0ckXU6uRjATm{ zlBr}cw4fe6l%pNs!0RWL#Tw#=3sQu_xkcL^O|K|3yHgt#O&QqUhH0%6EGnPy<%y$? 
zyHd@rm#Ibh&R>c_&RH$xmc??O`qemKD(aTnNm@O&^LV_!x!WPj^$-p}{OeHmoLA-t zN!wqV|GXaBSOSV zXonRa-lDP_pvToTP@sAcU!bO|Y7L^)xoN9ERx-(h628?mP@wwt@>jCRtNtT9tUWt; zz3^yFV|nrK6#BTF@ zEB8SRU=feNg}C+w%nx7Jd;bI(c^Oc#BI})gW_FVwZ-68Ag9|3v%iCw-JD;y74?3?D z+im(jF7e?Xao~d2p?p@oi! z0wMGsdWXZLqk+tSr_q^^huj|grdu>fsdKz|`3l}cX zt37|Jd*K2V-GvKeJQt~eZ;bCfX$B5|dFiS?xlr75WA(xXo(pPEmGpfrHgIYmg`I+j zN&FEXZ#XGHmb^NC;`UL-;2$(Uk~n`oAH912$*VRhDPiHU=RBWAdHmW#!zsz`3D<9j zGSl9^eOXomw#v9!6;N}O0guqz`(#bGa-?GWRqlbx@O!;P5Jev-%ys>F$rd znb7TW;jSOFe^Fe1eBpop*hIZi1-qB5XMRxq4|iOC+@{rbc=tZ|4-Nc4AQvTNA-W0zL{dQ?TycWnESsTNhH$%>=ZM7{*utM zbsp1rvXsUcbFHI{IKxP6CM?I2j$?66=ncY14F!?n-;8}|*qmWSJe(#@oR*Ldn(CG> z2Z)~Fr-^tC)W-Fn87{%3BiMAqRwi)(hJYPni798(htB)!O~iE>1%+YKc_v8@b|OYv zl;}9ECLK<%ZIK8~B~2&Nq&(tEEODiExuB1GxBJ=rGhn3((HBBm@;w#{F7mi{Q2tbY z!%bLGHL5s!#d(_W-=M$$GT3XLw@;h2~0>0Z3~H&-=_TukTM*I&5403(!;a+^*K zNo#5j9pvO^k)&g8k~Qm{k7OY`T{2?YGb!sVcQf>E0fYZ1UPi95;p$!8+%mgw*Q`8G zZh%Tg*`8Acdgl;WNpqIoIc~y%Yz->XWBG24B~a#b=6{PocLVVG^(D9S=%!gJdUsF~0H` z{FAb?&W|(mrFfkKSNO%{aIzfn*?OHv5}N8eSN@h$V0DV0nNVUI?3mFtZ{G&XlXIQS zGWaCdIbvV&RF$;y0t`;}nVKp1K=op$`vG5c$Sa;V%?;Oh#A~3S9!lOfkFS$u?k*ja zYtNL6XMJ3sWBa#UYChiU$s~+ru5zD;!M*`LCQc-tF^=*S_@MEJw;sA(7(9L>cy>fu z$UKjaUo_If%FLvE@OcMh}xWTipFMmJ&`{Tu*k>zJXo%iNHJb9xw<&^v|HM)1}=z;&#Bez2` zteM^*j{ptuB%!~^!U8}zGhwoe34G#HeTt27J1fVq#gcPo!l4_f>a6M6q4+~ivU>jm zb=c7s(VP%V+EHYB7Dw#2Jk>azCY{01*2FZ}o}0qx$hRV=xrcu@;!U$3R$Nt8_4L4! 
z1Tp4sf*h9=7iT<H?XB zsJH^{Oufc+?wekH(VN`FS#C!t^r)3%VW$`D%}bz$ZBx4vdUv8f*`@f-1+D>QK(RZ|b6)*GuD^Z-A}tZWI^kHeZ6gvM%yy=QTs} zIvN8{cEfI+uc#_|a)c88#0Hk)LQdL-geANQlmLb9-(AUc)K{3EeaSdYXA!8OcG z9`HoQAD@qT*#bVd3R$20uMA3Q4B%>6p_)(`KhG!aXKJ=mpv#C~r%yNu#4VD-8C$pe zMApjR0Dpl70GD0H+%m{&jf}edGDAMm2Sd{n%g^S=azp zqAGd9?fj?PFDVcgK17^sDP`S4@`32pzbNvw-!y}1gu-v#eAB#78@QG>^1+rjT^eC` ziW0c_caXY2H$Ma1&-V@p)akQ^>47cc=2ph)N|UfS8{-WX4Yp`*uRvv+j9OcGyEc#x zH5z}6hfCG60O$8Yv>jMEnvx ztY7~8fh_Dp=wGruD_`b0c=oy>cB$7`LxE7c^B62{&8*@9ln)sSg(1wkuK<`N8+EJt zUU>uYIF>X@=pc?Yk&Z|fO$Vk;4>;Kp$o{?|Y$i0@04q^?;nOnDdCvQ0k1VW^O66e z+M9zLwug}(2vbt~x7BD~0#@TS)abZONj1;G8M_4)kC8f+ofv7Jnm=m{K5pY)%U(3; zAo+Y=uIQWI>4xN|3bTe?AmgLC%$D%(ch=u0`}ghj3kWC%#>hvV(f^SZJbn)>f$x>4 zz<=f@wjjXw22L@ae-EU8pHMsdLWt7)FE?oZ4&cAp6=KZHBmQM5HD~f~(f$7*BMAbJ z!P>X(L_YcVRs6la8&F=PvqdclNdG>pS&zcjG#0R7$#R?DTL0S+DJi5GAyTiyJJJy3kL}fJ%Wj^IOtZ?@eTjO04&dqe1S8$ZZ zUBIl2i=$dW&)_C4GZ|+0lKou%3NTqym#I0tB1%-|xL8sCdHUkslW)!;IlK;b%L8)K zB5F~-u6NO$o83G2KP+71+>cEFaXnAT)sTBVopQSP1ens|;;6{BSud}64J(Rph1G#R zSIpsvH^!qcbG)Lz|FfjDp}6?5zH}oCn|^Kb2?Z#ynqieJEK*dQ?FT1W7#f5&qX#7J zwGbU}@BSsaO|?Wz(tZ4$nxjkPVT(3oVe2CdFu#;H7|~H2K{BKzTf=;nJde_XXGXB= zDknosVt7~EssII*BrAYAIEi9|rDpNy5}hr)-!-RbSbHPRf9qQMBw1MdGx*Rg@neem z8|BV`lqrAqhznqG?(=|Vk=9f|^WPT#f~!Ms@$rk-Y>CXUlw>x~u?nTnEz_o z6an%b<8IGz2=Wp$T@MHfP2HZ-JS)!o@^di&SIhQ}<_kr^lZnzib zzX3yv)djSbfQN-@*CPXgJ+NbFHaOOV@N)xHD@s_J)G_@BOyF3^clx2iv*zA0XF2I*#IVd4=N zxY$1NXZ-3PIXG39Qup}^+Y9qV?Qj3`kLd~U(8V;*r+;9!{R8uGfGt)9EM5UpZRuhG zCNNEy|I^u&7JUqMpI@j5ctRF7hD4uPc#?&kO_90ooL-`Hy^^W-N37;v1PpHVJz{(U zc0XGGA|(7rU|x=@{xg0+2xt#9KqMf1!0PJmWz%0G^?hi~D0H(>Jqwb98 zX0UBxOlYzQFhP|-hcdvAT+QAz&;0ytcl4Hms=^d{)UZH7b<{vPId^c4ZSs~vz6x-U zI?(wN2;bS{e%m~=Idad?=wE^MXMCO^Agle?ZP(O+)j>c{LTLWwpT$2gyLDdq17R8f zQ;qRfiabY1OSb_hrI;xYq0TxnkwA@^tN|i4Iv%@u3CNu4&CP({Kq3av*0iq!>F_L2 z^N;9HxCIzYGmlfe7}A0(u`VwC_TPaP0bF}g0*v1u#tbHYYzD6(-jXr@%RilfZw00k zc>nZ(r2&}SE2bm@q{zZjkTNtDRlo#e**6&fM5qT4EFAoqxNBG0x?OR&qtC#~k67G4 
zjA((~KOCpC1NR6#0}N*RiH2PUs)l8Unbp4njUKr6#R3?=icEtGT@Ps)1YM%M`!D|_ ziUC$P3ab8B7uW!R*|x#%&>3j8m_nH9|0x9j6oP*W!9RuIpF;2tBlv&E2;Nlerc2Ya+6r`qTeW$qU|OwMJIW3qR}Ci8rpYF*(}gPkDcx zURh7K>m)`rgFElOon!h#-(IwH{9E+~7?u0LoP@suS(rmXjaA2c$oVM?d(BkxB^s%z zo%=MLH0LJ*-%9uDX2AV>F<;vtcj!q|#0m3`(mmqsi)Ap&N?z~oh5qtrnH9a?o!c$7 zG;AN67Oly`tU8R7Zu{9)p^0j1?yco9Q^EbV#eT&*n$2A@f~uq*RAMMDqwUD#Hx03U z@#2VJo1dy&5Pc5GQkNh9`#|#*q>gi()bw6zzYHR2YhWU8xtHU3Xj`bzxp z0iVl4*m`4up8QMPNY0t7fahi`!7SP3;jS!UtlBfC(`2DeWl^RCXLhvb*@L>u_G0`( zu^;2TRfoJCfFutX2_?YaH3u=!(S#qf-W->I{7m<*P1vpet!e8c8=uYFDH~gwO%sGE zO20?~aVdY)mqH|wp4nZi9Jz5NVg*wv)VY@<1!4RumAji|YAx*dU~ zR_yynR;29C`_?TM!BTW!OpOG zV-P#PXm%oIO(IC~(?+NI$vbJ^V1@1Fama#<>SXWD5y}3A)Xaw1ng}wOV_qZc-L6-B zaQa@MZPj@JSat)I6S z<=Y-j$xAul(!KtopbB?9Af(mXKkThjw6es%T~4X01~{Zo^eCs^_YHehhh?MS9KMpW z+cuOO*@ikS&;-s6wO}5%V4V!yMTVvX(MFjwGTKBgXQ6fIV)fGE5NmFw^a*uyP#P$K ze!P>xCzmIosr2@VkFuj;6gl4+y0iajOKW+3sZaaMx64 zJm+&&YF(o^#im<+Bk2zAxaMXYo??jo^5 z4&BR&#n&2&u01d)ZL)#l)t9(yq`VK4g-Oq7VyYJD{iD(}bWt8V&J0eGcyhzWqAA^R zqw|^+ZtIZ+ZjO~njwA)EOO%oKRBu* zHGeAF9gf9P(=b94#_a4jTOy@v+G0FNU=0g7?YdG#_eaq>>&iXOEuobgci~vH=lIc* zOXGd4VQ(3BFT|2b&%b0AMz9_5`NfjlvjO~Bs|Ib&!0D6@gI zmaFX9x@0>(4o!n-PgKOn=lDX@C}oR+EY-w3zQ$WUs5yq#YZs!nMvcOhvc>5ZPIQQq z(sfA>d#Ix+mT~c!UZFeX&ywk*3Iw{S*7$W_`-Qt5Ia8L9!Ulq0fIt1-Q_c-*oFhKA zdZ2sR#2uXVo4lQw}SRH6Ww@ zwn5SlU42fUr$OiPmHB}N>pZJnM}gtNHgmnrt)bsD@4mbbS_|oasOOA-K(j#Y=e}Jg zmcve0Q6BZ+Ae6e2=XP0R_-zBej|;Ui$JsFR?c3IFUSH4EriXG2;hHyO7K7?prhe=W zWZYjr^2j`yXlPOlTHf#r>6g?&kLWQ)APjEh#_x;N*xjPNRpHf?Ci3xlyuO;_big6! 
zmnO5|TkHCw2}Mo1xpTGLJ=?#RcOCaAWw!8L1tqVY%*QSCKKs#dHEEaH ze>L!t2xwu9J4@P0&aUSyi0N8&J98 z&A@~t=??wUy7^Q1Qc}+>@A;*K5>On2t1MZ;Q!6)g3#V-9N;5>7 zGx{j9?TPRwT71PdUq`IG=QORTQQ3$u?vzE6Q*iP)*D)@rn7m!9QP-s5JKn78`!Tn% z(ag~}qG$s8GJCTbz=_dJ;D&2({`gK6IwDvbM2ZUv7WVaI_4%zPcXF1tDRI>wUWc)v z#FHCyt3o%ksD!ZJOJ|l(Lo0wLSCKUb&XEx}v93Pd;K6&oL2Sb**EEwpV^v%9UJxdL zmRDnxzzikX(LM3fX%H>Vl#9mKn0>IZBdUj(k)dIWeU>!{6LL_F;)N4Y3u9qT^!;@U z%oI^*awK8nd6)&FL(v{y_OPBhDzj%!=3T9I;a;Ti**XgMi?_Sja1KXaB6g@ zeSaQ=>>e}Pq22TrURm{CaA}NZrBLTCUV`|o1^Q&J@Rx-mM`$XH2q%$kEtZBjWpl{E zaG3y#T2$B~9pYe5dLde=w{gP8xv4WzSJs1lK%Lqj# z@8T6;izyuTkb}pi^%*KwA!sptqrU!_!40tjOTzt37jXy6cF+(E33KBmkdNGKD4Q5w zaz*$jtX1$sKy8TinWo*v{>%`kBNq18v0Um0r@v>^q&aS51{K!smNt?p>Y?Idhc)*Y zo3dLpK%t!NzdxV$E%WQN9@nb!ZZgJ7G)z>CU4EaTSlJ6c7M5>WKJ9e6dm1q-d~=;P zia{#TH{>(5*!ORD#KsHR=APs{-=lkY#$2d)3^+(}t5~+Tbju!O!IuWijgLh?+ugJLNHSwdh|c!|8C=Omx~C zUqa$7#U}C<{*=&2ER&qCTzZsjaJ)?Vz1}IkI$}ep($SCGa1@i=sW79YGgD0N9-(Ru zV~7}CnHGj9qRH2P`S0e|58LaXQDn7101|C<+4cO%!Bkwou-;6-nr-7;q@{pyu~@U4@tFbY_*XPdHEcNA#4q11F73DwBp#Ly$Q#t0AlGo@P-~lx)o{p2U+M$r?*xj2dvZ;< zr^k`92qRTQK*4Xf5}CbNW<}o&sHlY`Su5e)1jfQ}cIgOF?V|6MXMt4_*Ec&W8n^4% z8$5WznmTH~`aY8T5vx}4@-IpuR89Oy^j4wySH1IXjo;x^%Nshc7B#4Lt1aMVO`AV| zj}{hc@5GlR8ois{eE30e(rh#p2-Ts0t!DEkP85d}4uzJfbCi4^0LC)iaK2j^^SV0K zIL{)TVd7lR#x<|=EyR3-SNq(&4XqA4i*CgvBLd)8t4ZO5g`y890uMVLY5ZmMbstSL zZJah1mwcx--jah#L*pcDs>k%`x~I$Yl)CrHBV%*QhL%!Mjpc`4AagPQyoZVYUD$*o zxh~aAIz`4!FC40BPh*J6Wuz6K3CT{n>J*<$soDPQPCz~_DMo+n*Z|cSG%`x>ig1-& zR**#W&uJYL4! 
z!B5Tq(8y6rtZBqYW1FHMth`2g;aHX1(5O!^K?M2v*%_5TT(y8R6DS~L~L2c#KD|j zvc>`oKB6(c;pCK$JkZp$*b8Z$r6 za7j)+qntc!L5K!PesaG2wXggA*LiyU{_vn#fsjIy!5Vh(qi1#Dn%CzN&yg(dBj)-5 zxczc;wm+z&FO-An3uAxR7G5l3Y)A=Y zy`yG&rs`koTGm2Ne~R|s|CY~MNDp$R>b1dox1~@{{meT*czNU;k1Q4l+BHIqIbQJN z#_U*`5`2on!}&O?8TDPBML7<~=z{#+^rKlBzFE}Z$Uu#;pS_9 zR2-ZG3!hLJ92k0`Wf%l{A{6i_wkzCR}VC?VEi zkAli@$a_lq(erT=+3~@No;*iPU8)BNYZ?nG8I7VGPWJB#;xaV<*+C2-3yWp0@0pHm z4GrRRnnsuML|eCqg|w` zI~lF0%tBWQ&i)r3pKI1j-e$?b!>KZbP*>VdIL9IEUE6LEytx6YG$TzIEEIw97V=Ve zmV_ZaW_9I@t?z3rY!#tn9;bm^AhtvuK7DCa-;$*uM1r2|L{*k%Bsk08)`idDDt~eG zShQRI&f}dn>~0Fxq4lE`0W*^3&{QLbjkr7hQ37|wS_FN&lnmN6L`x}BhkCxRm?O)@ zss(&m=%pSgM>~ZT(E#SH$`mSm>gpFSA-2B10cAFN*Ea^q;VMFoCcG}>l^y2M+!c)W zAy~90N|aSe^hrXoBG7r6!1=>c*H<8CeVyDnnF8wAMr|pRP6vB%NGR^(cz5XdRO-1f z2-~@QhA{GDs=I=NV^Rg{NgaM#SK4NN*uspAdKI;8A+HBokTId!IQ0Qa6}1M)@7Mi9pt zr-Z;q9}9GKrfu2fh-Spkn3~Syo|+0VvM^aw6qMMysRtIXE_%H#og9*jim5-X$a!(~zo8}C2&%TyZvzhK(q@^PF|o1wZj#LWK_cNn82nb}L>ZrLu@K_@gOh3`@lE&+v3vb z_YEtfnI`jOJG}RxEmC=FM>$t&Zam}C#<#G>C+lyf{BzZ3;EL~HGlcnxC&yk;M~TsT+haa4Dx@7G#}WEHk4pIw@QNV zj7i}6a<)X=-H$+oYQ!PwsnTF#jErst!uv9ugJ|S&5YFN5BO6vU6-l|}6T9y?U{Lo( zuW4FmC5alpwxe!|q|{A&%E6 zHcsYiOe!XcYV|S77K8-rp{7WQmiT1_WNLr#cZz?Zt z7KAw>Zi0HTU3vAlBai9$8p8N)E{>x{?4@blp>`ays26A0Gy-+uJACC=cvw{VV1vAT zMryFIt;BnV@K}h^#8KZ1yVFq@pIMh9_r21X2G>(o&!a!YVXq+DyJiL=_8LO5sN>@(dJca0PtIeH4SE_zqDGS4y zRx+WEGY%KSnBMGAoyh$FiglxpE~jH*`80;yjBfoEv;8l`m>DAeQa=xlip6dFCPsaB z)FAHa*-)Fh4%R}J7CEO*x@0niCD|zA(q?Cp_&qt&n2#^BGSJBxDwf;$J73S9npHCr zzuO@oxni~-oPW2+c^YUAn3{y-#5#&6KF_K5e zOvj`dYgg#ZsqHG6K0*co;x&cpj<3VSZ#OCzdLLR^M3{B%`=AUh6s|Pe)kTRgB=$Y^ zWnU$I&)#a@`*(hId2RBh{sivTrk^cy+0^EgLW&x3HnAG;6VvndFV+5MWH=(f9L97NX^x>ua-n5BNXK!Wr< zEOB{P6IM)XCKJ-a_dF-yU0>1Tt@MD4(%a3pLcw|9ofzyNF|hr9PDD@8-4gf%PHU)6 z?x_6F*7W_2`d%;_%l=626!yEjmJFAIR&82Rl;rmJadAm^4t9pEHlQV95_z6wL0D*p z)S23aG)(UD!}quM@)XW&m4<3(3Cqtyh}2dg7@i_TP+)>($THlPL^w@kIi+^Q4DC%@ zt=Cu2e`RGi8apkM3bvArJ_q8%2@n_?Wh9Z)_ef~w` z=8W_g=@d54Ud=UCL5^gpqedDHeKo1$WHVlxf7tEiDeIl-v)e_+4>dAO7BoYopG+nN 
zUf=r8oyeEzaCx!o+SoWrqj0w#(LXOjKNXmd2&b%4TQxEgKQ)#1pX{0VVz|BC?yeQ4 zCqW+(z%trRwoXx=LfJ6Iybj9g z;gHvQCN3gpL?fg7GVS)2>&+zsr43BZ5}MWL?aWz9EzQlAO69%uYs--C&($_{rYDw^}U zL6dIhXQkkJ{OHR23>3Y`ZAy!v76BN^?Gsn-_#3KB4z(8|k%-4R# z)?`$rk#!34u{Bg#y6Z>M=dv{1PU%b#Zb2ea#z2~}NgvDsVjM>}H|P@F2VDH6xXzgdw{u{ zzgr@wDT=HtJ7qLnrXSOx?;=^%zex?Orzk?qRvx*b}>%c4u=Js`KryHqowsbT+4+Q z%6l6@zlA1sk+eEG&g z7u(m(_=;Yh!|yB9x&+$aC2AIPG7B!)&=HuU1TDb(=3lm)R&obV)x646vWRi5&cP|0 zwpOg!veYEw(x7M>X(|0J?@5zcWL}9u&Ki|M&6s&k@61kHpcXlY+|UQy{ih``Paaz8 z2q8>huX@D~TlPT^>6Mz(So+nbMO&8i;30Mq>46!IG>6{1?Cea8bdwoc9@#it%+!q` z>|4s5-O(3e2;kztxo+cBEPeBJJT+*a*Zv4|7&!501L1XSWh5T##`#QYfX9-yUbiBk zE|4ewD|X>XwtqahU=(ZjI^CmCI)RT<-uMOT>LF86iAGv}YiRsoikJ=}Qe(HxD6yeB zDsZbisEJ+%PNy7348ulty|qxNkQ6HsQT8rE_HaNTD;dAFw#2{$n%_JF>ho}}niMJ) z_66<;IpJTG2RMKW+4sI5n&P&1Ln3Q3>lr(B;f>mT{R4Tezedj6>M{tvec@EayeFTD zn>t3@jsr=JJLb4?iGkxTsV2?25YzZ6;x*9jVxVx*sDqm$`fNDpyUWs#)#NmXPfz`| zVR=8@{j#%NmfDU^bmeGhEzrilb4tPc3P(lu5;3My12bN^L#ObRF+gG+8K8(Hj2+8p zQJzG2RqwZK*L`bxL?%56pU+6Y+wjzRJGKPVK;AH3>eOuqpbzkOV8aWtKhVY6vC|cf zf0(ZTDh1~^=p?8k!!$!d5BhMs8&m&DlGW7#b*e#mh)YkG$XU6qf1-lk(QBC8_8z1k zsQW`EZ{HO%#kBOxVwU^ps7)gZa*lDPBbU5E zk3`Ag2E6pnOXG9oHeanzaB|P&2OOki-UW3qBy*N?ME7~JEq}F^9Qd{senQnf?6KJR zRBW}T5@s3n-DeUsJ$yxRnPIXuv^#0FtP~7T^AE)mjPzvH`61D4YWx@JRdRw{I6bO` zduK+YYp!!Qz^X7ty-W*EnG^joC|hDh`bR0dpxmfvfreW%WehLR4rIji*%b4msW!(7 zXiZ_g(KEF*UVRwJ(p>P}s?`kv$%i2BoLmiOEZI6(zfrDhwXd5y;bv3<@aB%3dzKP- zWJ5e2Ys;rC zb%ANd?wa|y6)mkNVZ!?7`%ZlP65KR_S;p@DvIaGATS8X6^*yxSD$C_PQ#z&dPsoAw zy=>y%G%h`)9df>2=+&CXxb~f0un!=b67+NrOHuZ4CQj_Yi`a!JEswmG=CHl8-X-1FsA)CH zb>oDpW7fa%N2-%h^0|V!cd8C}{Vmf>9J3 zjhLdQg8B-6Ra^I~+0s!Eawl|7Z|(%VYfd8bCTA4Bh_=qlX`XS0^?q-)n|Ay9rK{`Y zC#?Rmfay#zHC0;=AATdahOX~f4EI|qsf`JL7LnJ-5`vS-(p$dP5DhTIRNHRnj zUy<@%ll+!%k^@)wED~cyw*h)6)j?ZW{FgM)L#V(FPx$Dm^=ivKoEVe6tCMSDC^*F|ZwpBVJbcYU%xR7r-!$OCaB#@F^UDZLyasiu~2b-_2-=d{` zg-3G2F0*`$v6(gYDqE9MU?{AT(@_`zZ-(|Zi;TbOx_nT6)OERi3#fOZIT)%kyK_5N ztuSq?FZLxBxtRHps_d?CY>DgA^PD{KV#i#fK60(K6wc~v;7@Ai8zb77mVm|-6orpx 
z#=0}M(ijITj5mh`UC(0awG3BF8W?HEmnjl+$R8j&shCkhiUwu@vxmFojfq+Z>@ISy zH{-)8#YTeKf1dyxh|V=nY1k`(J!-PwvPx`KLrtXw>b!&B(EdCM*AFl!{XQ9SYzd+2b#(9 z`;+;BCQ~Dw2N%nRQ|~&AjYYD=sDHx5_VUV5%ZLn}vtRb3q3jp3s_?TZ9c{ys04d>z zCv-6wNh4Q%$x@{>NH1PL71ra6+P!zDO4C9= z9#$GOThJ`=a!Tc{5&x#U@kO^@sKclYr(Vw9I3_Bb$^(c1zuEQ zw;bCE1c72?FdHmgSI>mHe1`t&AQJ2=YiQ{+0!3l7Woi9O!`$*)Hch*;pZy)Q8d`LM za})xNIo~22-0NqhHj2jkWvkw=j`mBiGP?w7g)Q{4swZc5?rr>(oOVB)GyGQ6uqD%j z50qNmDQ7sT-`pJ&lxl4>b*b$;k6PQUvM4Ou`^l2)$w5pyDd+eL{a0;~9?V-|{gV`< zso1m%q!6{>`q4}iJVD4Ss}I8mT3!^}DwTs`e(14+XXEiHeJKG2-o&g>T=iCRLYe#P z)?95ZnN@T9ZB!+g_be<>!XzN3JdKotlN&6xdFM=uDzj-9b+nYi)hIqs7$5(1NX~s4 zuArQwgU1Q~va|0l`L(y9LOPd9uyie`446p@fsb#*-W~$iXr9;{C?axVFjRi0%abol zo6yC*BK+1?N=>~NsZ9B$EK5l5%c-%8lv1k-qVm;{Ldmf?=ds6)cw;cfjftsVMEL+SbdnLrR}&G#;MFL?!aA{$5%8Q0!JI*SAJ%nqhCz7mEL|;L2o9oeg=obz`(U zRL?5JB?evQaTO>gZrC8>t4@vA3#_t2(9llRrV067f0#vylpG&iB*eE$zt~GQp5Zj3 zYCP_9QMP?;c6$j66J}zm!%j7H0vd!G*5{KQ9{C_IA|1r)YBqUcBTs%<2c~M=eppww zoLlyDj+h^9vgHKL&ymI-1+bV+X%R zw!{|mYWc0x6n3w;R`r*W4Nu(;1No`~2c7{hW5{U*H;Dart!>^%)S?IXGvAWjh|1x=h^`h zzLpvkbjWHj=&ucu5;k;owjPz|gr54>#yKoE9}3qz6lNRx0PGV19HYWv!;@820bXYyb=CjPOZ?zk_HQi#s>yY&8P}sLtl>YN0 z`cofr@7g$kOdsqo%sv)Nx&?yz*reUa+%R_c+zpIA3a-xF;+%_IoA3k=Yn9e~U+Cj~ zz)o%QJ~(d$Q3_s|&tKbLKHx`B+v9(odTz`^gZ*5@dL`cjncRISSwLskpJ0{Zr;dEY z$t^H`s)_1RBEE3WoaF&MoHlT`F%l1D?~xoOsYlK-GD$9>T&`h5Q~l*>zlfH^0v(*1 z-fty8&0gxrvmaWT4=t%?YQx(r`b%|-;(Vj11N^=AHQgpSM7x?hoAsnZfPbUm6+PLFk==o2pe!+P(vMVtJ&)2~Po7bFtyB0=97Jmz+U3}%x< zg+7gK2zplK;F7BwxR3|AVAvWc*|1z_a!HIcI*{vY^k5|i$GY$h%Z@YMC1Ee#8_xLC zoX+&&0YB+i$63&V-2Gi5W$ z_*kn!c}T>_Q}GUKECXq#vTD_p?QSUtTXp`Eofv6p&{$G{&bJ3{I_RB6CzfG?z7|0l z(;iu#<|f^hJwrdRA~6k*V1*N{7{TxydU5#V3Gd}d(Ah) zpe@E=PJXCff9LHqe~x<}#>}(gHO&)d5%Bdb*%LEH?;DtPm;3oXCOwR%ZKy=@2RP`< z#eQRfXJR&kW6shI$(V1oYrd{`Ow#T7%?4HawZ$-o6%A`P57^93KKTym*pjBx^jFjh ztCk#IHR0NyboM1?ZhUy1W=r*9JjP{eWQZuiXXj_$QnIa9TsUTLzqn&9qUbCBJ9Ec^ zqL`_aA?0{mdo`bi&Y4+%J-;D0wvOX@Ihwzrj0^cYIKDwx`?xo`$&+Wu`5l2zg7#-F 
zZe{;*3Cwlkl3X88#g)z~?!LsK*>q=XV;~&IVwZ3kRKr)&zL)tNQSY{$o~2L@$M&w* z+PqT%lw@YhlQ2oHf+0Hnii4WP-?sPbC!X)0g0jDleU_f?POd$fNo1``dpo}NR=xiD z&NngBjDnqltlW}hCJeoD*Z^$xNm#`B@pDVDgCWGI^V>bu9Id#Jf%DGDG0Iij2Ju`onwjOl|)qn4(9$#!=3Qz69;G<>|0&NhL z)y$hkwRTbQV>?0`odztE!S2fthL|LNM?(|Wu{0%c0A>A=SDG851d1PFKgK@+cBM{L zj+=@F3PWIIF)H&0Vw$rc04zbPU>32GJ?thyFDz*kIIa!ux`rIyU+ADsH7#Dk?w#Cb zr@w0HG?7tdZshI*h`<;(^7V3I9V9Bt-$`#bN+f6lXrYKm7OW`E*WM}#w)M-CSiZ6O z$N*=mS3@vip7^lw*|=#Q=7wb)hL4slzBjNCpko+XIVbHF;x07?ICCdPl?NC292gEA zaHDsq!WG0Q8#8`R&qZb(#?+k#dJ76v|u6+4G^^Y}pAhAFQE>J5k~I%AXt@Kk6sqe!3Bt zKh*vr#4hN?GfnT4fXTW8H7yi!3$)I;a8@>^e%&e?T#Rf^x|zBncX)H>fY8V4&rC8R zck96qd4PbQrWWnB!WK)B?Wq`(EyO)bXh0y7dCD()jz)aFv`&j6mlsVenEamJLFLdy z9|zE{H1HD~ZL*&3VENf|CL&*qM9huNAYM%y>pl0VdhWVk_dU#0!zrjTEGO;f7pdA zuTZVv{*)SR;3m5Iviy7S(z-_4@B~h$dLV6uT>)Y&CKEhmJSXTZHN;@JA5=0s%+Rz# zR4knqMfb@LSvbcpiLG&B_$Sn37#@@!y!asZH2w5u^x6Y>S&<^XN63fMa=FPY4eo>b zWRmtMy&m&A1yNQ1F}kuGLYSal8yiRJZI}DlNMn==bCqGnC8rQ$9=#pCpzEAM>>Kv( zlNk>md*V?BsAu~wh9ZyvLab}fW<lbfU!P+c%fvA(!2_{EN$8 z)(Od(h^D?;>hZJl&IB~pKPR4SD~e_wD}7i;%PPMT8m^`}uU6Qf+d3QfXe&QT1sOj! 
ziSIpw-rt^Pw0Cs0_4j{ZrRTuHs_qLo%EZe5Zo9?8ECdqheNfB|aEz1v*nd7nDP+Pc zlldk1=;VU8R(EY_)k5uQXH;d(jl`bHVbkAm;hS9w_?IyEG;gIr2ccSCLGFD~VK`=1 zq?AYiLW0whrXy+D)wiwV!of-K?a~f6!Z;K%bY?)u~Zl1;L@}A4LNu1v!$3*Jd z*m2%>e2SOconiWsBzkb0WY*L4F(ad>W-Ox~?QBnG$LTjhHjH_KR(kH_YYSX$HRaxw z_@u{Nw}5N>>US-Hd`{>=;U5F8IYsH$i;SmeG;g1FiCS;}Q~7C8^w*+f{{sF*Z7RjE zJ4kHesv2ut2qt2{EUws;Ejr+>sh-`4|EZqyMq9>IpPH*%B;de#Q}@>LBdYIz3xPF` zUS@@E0sf{z_bDCw>JG28n$7fAr|yAbR|k(c=_f{?O6^wa@w#KDvtP<&uB7FykSQm8 z3R}K6VqWV!xwdS3`t>OD z>H{>dKwQL{CxkbWm=!TBzti|( zmw)@ZO-w%VsCaoEE-Y_4nW<8K_+CZ(KK5>xZD)5vN%}Qk^LnaN2ej!OC^1Qs%}|Ic z<6P%2PF8s1L8xhG)3tKJsP0;o$(5VxmQ%HVyOFN=M2^X?1s6xXV9}2oT_{E374LsG zCAL(nx$*w_6Nz`qj6M=$w)Ny|G;uhj)%JX~d`!c3`L?hA@5Zi?i5joo>*DUDzD0A# zGgNbPcmkj9FPO%_-@{=SmR%D%pSAquBcMyaPaf9WRvf2+a=po|Z2_b9le_KM&A7Q` z&1>d_^MLlMzSCC?z>Q`z%q(#*`K0aQy6ucg8&Ch14^?epM~eGGTq72?sAbirp6_o# znks>OzmqBwCdc#Mln;%c#_Nx&528mkB{sRHT$NdkLZ@ajo8}5Q=kRDIMM}CK9G$Jc z9WF$21f20mo|-i#p7|*?I-PErv&JO>MPHC#eAb@1yUX5tH7qUF{@d*}ynAZSThC}b zti6PsZ{Vr+g6qAM{@RCn<&)vvcLsCqO*sj%%_;m#kDedT2I&6d$qT4W?D_h4{YvBY zEuZ$8`%uuDp3D9X&gP_ofh$wu3`=vuJnKVXzI5S}X+;jQ-PhaHOnWXf8ab1r{g3g? 
ze;7lfc=-XjUo8J-6jdvJ=5}e|oJ{y3Q_5Avwc~ zi&n2+{})(tT5_v(7{F)td-7g<3)lTC^(Du)!!8T#PcB~QRg1uEEM;cH&$;klJ#N@2d$2aoXyf0k{ zjI8h#09of==QN*wPHoH}0)gKJZV4LXR(0#GTsLsl^qKk!_Pp-j&V3M(6u#tecAIe8 z8o!Ktb*H$0tW~@^d4% zU<+uQ;|wNun$L})-Cu1|2hl3dJ7@O2xK}zZ21@e|HQjf{Ul-heQTcvwrTV?;DfGmX z-x2JB93U!o?znK4|F%9t+$PC@r_g2(E$qLsFTNp0I{GFJxT5cNXS-NgDPjOX3~ena zW`rFAivh}X6#qS+`G{JcSYBRlIqGZGYyP6tNP3OJll&gHYLl=3t<_f-ZgnhOOLA2NVL8F;?&Rv4`vP5-yprQ7S0l%sMo__BWGtLIULN*7W{JI!y&QDKmT7J zZp&+h>M7Z3R!;RC*{~w(4H(JOpx|<*|OG`{$gjL z8Du}5uwf{MMg&g&j!orhbZV$xYw`flQJ5inYSFa!cA7E!4;qY-W+ULcRtAB2II(}R z9X8vc9s)HxFn2%MO2C&|qCNDz8Lo{WR*?espZMLxAyzg#$`Ku5CGth( zJUW~_svH>X^R-F>gN)o*^>nNy;6I2A000=a(;C-bH;YIYO1PrOR#Ui|#6P#U4O=Ct zmog89DpJkav*!L`K;XMbyLq{Zn}EFTB!XpkHK3h?}cF2jnuAr?_JS2?J{ zA`hw{{tc9HP`^yAIV}?Ou&LZ`i{8MhaT3mNXE~-~7asU_EGd$m{07n)Q~GEnOh$C} z`P`7&f1jG|C*`Fj|0de>)hB7e5coMtYd$sLj|qi^^`aDHNe7ic{+om3g$E9p&X8DE zaRDaeiq~5^kT6Z*#0YVC#iwn|I_Sb#eBauoLy9SOeXJ*R0^_WMuQZ~>(uQqrg=?MR zhF?rz5v3IcF!imo2u1O;>SnjfJl4MruuYh6_G>j&oS=kN#DC zk*O_3^|d*NuiF4unN+6&}m;lt1KoJd*;R#=2cprjQ-^-=^*%$3-*_|)Q9Ih)_=HFY+56oPyaE|Wx!VE z?fq={_q6RHI65aaeA3=$glAlHf1@dRmlZAZaetGoCc;RDXu%zmt58|Qr(aUof= zrJ)wm*+C%j6pZRMW53I$lJ7$0Q;QMn!f0Jq)}YCTmZ)v>H94REO#1)!0*o8lRND*`hsMF1{Gh-j^jxijaNgse<8>oswt|7akRc zM3t?`LF?QLnoBW``JV&ToSGP_`T$i0QDTAYg*F(n8$yUHKn-fAB<$R#I2fmRgx1i} zo%aF)RnegE@?n$nu)(h;g$m@zzi|V^f_9h*%uF~h< zR8@{o-{yh_LlzR+hr!vZitB66VMaSdJn^w-t?xlf9nuh7O1LTy0#PIA^ z;^lJ`uz(Yv$?E|8qkOsO`wvT1TmlpEQv!{j|Ja0bNU#-N6t&jB>>lm1LBpb0QDiX* zw+X_-dYjF%^*Po+FZ~H7riG_ z#Re@JuIz><>b=lYEq}qLGW;&w%b*D~lK+4YJRx^K*C38*?Q$MY9NfQWN;5X%%2{`` z`II+CcLgS-U`m6$Sqo=S!EL8(L&EZl`vYbw)O?Q??OahyOYEsHtyJ)1!lKdpYH^K+ z7})$Ve2zE_U7ScfM_0+RFsQ8=?dgvN4IrcBN=lXi(*5;PNtEA($&q*}MIc2V$jQ_iHNx zSy-2I%jPD0yMBWk--Nm+8s(J2cLLN%pqMEU{GFW=+U#2O9W-t`3Lr40xgy3FzypJ1 zx$B~AUeooIRHNsViCr^1Uwq4^ok| ?}2|ggReHihZrraykfE+=*3wQVn;}b(YL* ziFRRJxIJ9gWcFb3OZ4<0O)UPctR%$vCO%cw!~a!^rRapmB_va$A6J1Cmse?+fZhcN^*!vmP%!$&)MdinMb{4CZu>sY{`Nb4*1G~pA^B4gEdJIOlzqp&g 
zi=3?VOei}#I%5dn^0U(x2A)tlC<4PddC;CzFmA<6yOr6>{hEw`$Y3#}_C(X+SlQ8i znfRtN_8Rkmt(pV8kn%kYGd5%JGcH-eCuw7$+_4EYf9JcwRf@X&k3$HJbhVJOHsORj zb^LNL<*tWXBOOqdEe~Rq6H`4&DmMAB*X|rK&Q~6;`xfjV!Kn}EkEem{g<|H8uK}|Z zwRQAj4*xQ(zC%p}rC2wX7G!j?F6gQNeOsYt7qT{{Q2qRc-+qC3#SeDHKJ{UvD2!$i zMy41Q?WrobN;7XE6~Qce7-0c@NSB@^oUovKa>PVeBr<+_!x)#IKfr`~M)4vIx>X+5 zY)=55w3Tbpt2NmSo7Km$UIIOCE*(3)9uQHgSj9Hfnl2mTm*1!VEo#-{sW-#pu_aYO zsE)StaoX~?1X&P{?}u?oxRqJ8YE(l+wI`kVW^8yHA?`)yZT?)T}7V6<2()GSePj#g#fg&$3g-Qns$|bDCw)~Wl|X$lF4?qk*mI$ zi{1}IkdGC`%Y_08c?S}+5VdZQ&$W2^kDkZ*d1GGsUd}bs5!&JR-7hTsc1?L@M#pcN zlMls=meV=2*D4vQeR9_KHU5jpUDOt0LoXl=rxYe74d>))O&$}922KjmwLs)dnrXQY zeHs~8j_3H?g_$n+%^^esXfz;>HsEAwJI2h2YDqLmfu&WU!f)TvV){bXMabnUttA04 zbUt+Ym!*|Vh>b1@b?yhg0FO>-wuuD#Ph|qS_)HDKNYS_;2e2ktjK&Y2S($W1;#d@U z-#avbP?#~-X0pN56U429kKh<(*23IshA6Lg=OAw2JNglN97lu2R-Jy7`O<4Fzl%_#s9h`#3774G?J3arK&3`L zyv~QQZ}Z?olqN?ue0u6R3nLU5g5?H^-uByNcwC4AQz%&qBE}Gtm@w)x{MAx5eHw$6 zSfZSONa)xL8g|1L6q%QY&zKxRp1QAh<4v~~24EB<2z1DdKWx@n*0Z#iSD)3gear8R zwQ*`BIYPJGMrffpKsEFvnD#D>J(;~ebcOF_bUK10>gb1BZK>ltX)ku$=Okqo`j|?k z)YzXnm?x7tjWv|dli`yrKna-{uH`%ZXn~lF$yUH<=cT?PWr~41B|~X2>iIWM51mAR z(=s_}K{~4Gag-dn_-R_)(1)O`fGck!?(^>7(UFtM&48kJbqZxxbz-7Qo?1^ZD2E3Q`C0-&0^r$rthC92RfCUd5|Y6S$L=$CM7%7Gg3NgnZeGZ7fd^ ztUrUT&P0&%grtnCcnCFxT7~bL)D{{g}XZPv5 z{KPG-E0q^}nushI6V7A@(E4Oh@D+26B4|(y&9>!g73|Zc!ap3(R`VDZWe9g7zGebI z^AxBsK0jS$twk&+RMk!A>asB=ayLx`WV9R$+}eujg2_jK!4-`rcnXi@5QBxzX1MFP z$!VwFh$Rq$-8JTPtOG>SAKJ&ug=zy^s0y4?Nsk4<_IpY?S*wctc4MCDdVc=TeZH|^ zELFqr^xgg~`=J2!RdyIOC2~>hlgIJ-HG=SRzorJJzh=1-fYVw>L0rL8K>j61Je7>=eT&JkW&-TuWcqCTYf zYtg*5;XKMg?AT4MxFM_vK7}&ZYuZQ`e*MTGuQtIo&u`VEsyFNXD<~eraGrPJqfnD& zLBe0AU+&2ce4ahk_!3PndKqZHY)_-Zom-JNx!*KnR@bWnZ@rPd>!M3zkBJdfldk{9OeT*nnPjiWE z3WQO8-fEk7@xN43>J-w*@~djs8Q2DBjK9JhWg~ZmQtqc`o~8bBXBrY_z!GBP40nLM z%}s2#rpcYT(AG7JEHT8c>osbeP1);l>#HG5i2~S_iR4rB!yqxSnK77{e1E@%WN_?c}bm-sIjKE7z(~8{)qA{X^gsvNhWHqnH)8u&nsnCzo!;gVj$F=ALb_! 
z6Hlr~y>4~}YxBRB7cWA6fqE#;H_m0BPw**R8@}zFi|j)`jokQ5?Pw`xk4zH?sXrR= z;dA|v8A2y;79dz1aB%Z1Eo1ln=8_r&fXn9|Dw{kIj?|Vw>pr5^eIAf%BAXw-8E`ZQ z+}z@$>Hu5@o2Q(L47fvfdcc}?Q;o{77hGDfvDjkK=?EP?f_3#tU4**EEsj;CJ~Z^nw2he71ZEl8GA{j)V~|6o^;M5UjN@3Sc8_dUDbG+d3FTqWOGM~~y; zDEy+n-oEh=*BTbD>Zm0a&o@O(B;mj2q_;jh8d@ESZf%ei+u+RWF}(|smlUk-Zzj>a zkTgnZ%e2A1Z-1X?@InZ-6WE9-!Y6^pM_)zLTK)Y&+E5ikYp-#|sF=)6Bt&tlmw(`T z<0(qweQ3h6e3udJlXQu4=LBtp=<1aXjx>U$jy}r=CGGeEg)yujN6NpV>KNgmbr%i5 zw!|Rat%N>1l5{}03z)7gB}%l{iH0HlwP)lc*>4z>Wg5E_(><1Ro2`WaI}JuNA*9vS zW$|=$`%*x+vrBEP2DUoqI}mDG(Q3W1|5XJp4Ph_0sh8AF{N=0Vj|XmU^Rm!i?y?F9 z&C8lFj1`lv`@$lhx04T0w6$lje#xag%CWcXh8JGSQxgqYq6mu2{S%++TaurqAy2Iq zLiV3KOS(KQpR9U5x(YT6RYE>Z4T=i}nSJRbrB*Yg{md1kY=uAJo`|>U+F)wm(D#@; z5l8dK78LcVaED|fu(EgR_Y`UgoZ%vPCW!4g$*}}3O(9ishLLO&M3NlXZx3+Du@8ysU4J)*$^RxQDA05+ya{~ zCFjFuv6>~AT*n8kFQTIVw2eNZvUfz?e5I7ddOZ-DLjdv1>3Ln!QrJ}(udk(}248@n zr|=fv{pZZ45QiEO6V6Q$gKxlnA~%*B?re2kA@h=pttQ5DR`VeG&Di!M!B(j@MNSb^ zru#`6B4<9^rY$s_8Y!pUXa7Bc%Nn9Lixg`ty!50u>NUwixW9M{i;ZjL;;ERdPyc$Sht1>J`RH~F3CdR;vUA!H-;6NzdAMCe!Mn9z70)i5t*Xu&7^ z1ra`mIMc{y4AL)}$6KCK+S5fonN+}r5IhIgaq-fD57mY&8;Y8qv#1MLdwnC%% z{2fB!HuC7P?W7}j=M(+;%?jBSHmf84#mM(0UWm|WJWw>gwuj!j#y&sU^;7&5A2v!@ zu}|^a)L+3e1~;O`f;-%FCerqMv*@_r{8RYwV&(E1#3OOj%(mu#_81l7p@hv?_=EyJ zk?Ok%_p_>_djM_lf+S(xAC*d;cp2Cy#Wh;dZJ_;18Bv10;h#xX{j3M<0Ac7BUl9Zb1c^PnfeU!lW)ra zi3>P2LCop(J}Mj|wLe@v?Jz+Dp`(60hOm@Zx_glpXd$R|#Y;a^HB>aiELAo9ir`|J zr6d%Ik>1Z};gfa}<0cbF*s~FY zO|i*8@i-KLwyVDAr~F(hxga3lg0pd+Tb`gCE+FqJV1@+9wyhI@9uOqjTHIdmodeL} zeY2bXtUx-MQ~o__@4+Gw#ZfEPtMGvQIR9wyvJ3~D>@d%6Y%y`F0=YZ5o!otoCNYMuTZk_= z4A(F1GC#A}>fX)RfR$6(#38j=^pkpXq?=OEphGqxJ zyZGYMd8$Kf&eS>W0e3=0rde2Chk zsDa<;u#H=R-OcQP@0L>qZxHd0>tqqwZ*NR8|gZT z+&^w-H8z$^fJcR3-eOPk8Ojs((F$fPp-$2YWx<<()e_DXhGxy_gJVsi^DzlM&hXrC zMr=td;~g48i~fzzytRB<~_Sx(&_2*gC!y4^)@2yNJ8CZ~UP2d~CQ(gX>_HZm^ znFtuq!TA+N$bkKw(VURsil}Mf_?!{Mc3?Y|l8w-900!^*-jgbiIs{wNzgC6|}VPE#7qtsBHgnU&<#JFWYeK3qumvSf0>Leg}NR zQWXjV9CfJAe*1>Xej;tiJymq`B*Kru1(U|ia#`Tq3$>zb+qTz2 
zYmU7?ZLf-vjHi1Gw9HEw(626`ahEX^s1L;eRfPQX-t1+wELyOL6%zGCr~HWs!~FXO zAC-W#;LH121~UV!IIj4>M%Z&0DYPyHc`g_pl{N5Aw5BofzL|!zn14*EeYLKbPR`Ic zbf<2&*DP}gjY&sjQd3z`FM1-8t~dOOpLJs=44ALr8jTRbhddmv@!vm^!C@z zvb(L#Mk<q!5AM^n(?Ix+SV z%7P4KuBhjWjw(#?a36b(DPl&t0qPebO#B^Z_`YPXu2cK0v|hcVEGdI34WdD8RgqsD z*Oqy?M#;NxO$}7R98f;A^c3QFJ+Q424K?j$Pag5tz`5{+^iOgBqPC?C;~dmHm^9s%1eR-qInh0F%@y;dFRnzbkT)S4u3$a-pUM7+RSz8k3Hf73$H zQ|Kr0#Moeyg+3if2XA$1E*dPvYzH9__hs|+$a@vTU0kH7w!+=hXg{mJzmFZ>qVR-< zO*_)vf`GyoheFstKbIhIX2NwoOCfK9cBk0Awdk%MXK69~ZQC$T7PB48dar9vHnLrS z4buASMSWK>-Q~R)nXG(MT9Z}lMq@M-LTd@r03B!o4uW>Ha+Y3G-`DHD)X%7&CLIXf zj&?pIP_t`#tdnkSP)A&Uk1-reon$+_G>uFJtBx;DRFrhBAtnonsks4yd2rfTh!Vg{ zRj`tmKu*7Nj!W-DfMp1LE4-TP1cN|AfZuxjzt(20+his&S1x8lEnFp=5KN{6Xa? ztL7*mWl79r67KldNwPM=T3L{~xpNmzD-KA$TKla46}^;XBz1CZtwr(gI$G~qFfsCu zK!SQ06SjojEXQ>=G8L(&$TH${etW_^ba@X6>kFR}O762TspP>toYv9?t{{t81Er}Y zgPvTd1pYd8vUHZ3VWy2mTm}~aIW1Km^lbL9+uIyK4p7#GQ?eY+F`!-Bq^-xI9*CqW zQOUR{j1N=Q?Y6}clER|^y`T3|r>XkEU~RSa9{+cY%7pr;gz|id9rt5jc4iG#)l#=;nNlN5A-pYD~1R=_Z%-y2FrkQrb^iic!8r?WCF zAu`mF=uSGlYR;i zD4{kqn}Li;pxWM!cG0$kkT6-v+mZ??zoYMQr4*9L&e5uhoJi^VIWuF#R~Ym_kS{@$gTQ}FJ}l|PHM=PN z@%A?fw>MINc}fQhD&;XJi`8lHLh%vukD#({idf?gf&GHsq#vbPtOQRg8o#4id4#EZ z_&3<6SvgA-QJafnBiRUZuTYuv0~v3r%99n0>ui4Vr>(UXwcOE2r8c$yG?GD8Of<(< z8Xa#ZAwQ%@nnOX>-yJhF=K0W#k3(gABq86EhroU-Y6r!HG`f)TzAcQtHu$P18n4w9 z_zF2{Bvhm8-479!R79Okm;Bcy5?6!e?%W>zw(mH^}ZMLb-%NXH)EmL=UvLWxoVuD|#>&5?2@&Y(%D`%yfFqBV!hw zo(s@0PQaQgv?1%1am^Wi+*3Vs2qTmG@(a!?8Lm zK_cK6%(rU^Mz;Cezs$Arm(YIZ5^;i#D=(#PDf{|AOK? 
zol{tCRM)G_;?_hvrY#$3zORj?#Z_yF(=U9j_LS%9FAp9O~ z#qes}Ev6+HNvwHYtF7pcWiFX{4O+Jt`RP0u^!7zN^9Ym z8IJ<0b>o?ApO>mFwL;i{E5BM&RKbz(oW56TmH%(N- z@6Q~yHTb4M!MB?{#ZWmFa}6{jUqu9R#t^iL77r>%y+?3Q5uZBw8V%#~IGJLz8h_)U zh^HhzJ`7>>g^hEAf_MNtS>Gz_N(~e+j}xt5{ZPh(O_< z@fy{sY<<1>6PdlG?uJa`(bK+9t?We?F%a2M&Qn(}ly&(OD}Ia1KoW{Yc9uj&Z8_w2 zA6jf$KHw=o0~CnFY57d+z4eGzqrq&cGuAwI;}m)kuSy%QPW z;ZUfw1PL6LD>GmAX;b@EnX|>1)0xw>ML|SI|nCM8V?ZM4WKQ+Tv@)c_EMzyy}$5 zk?|J=DVmJ;{9S9xkdrF_MbJR>egX;}z$8^A3(y@t0?X1HumC>D=f{4Q$pamgk#Wlk zeTgnxKz~=H%h1U>CC;0;tH+ys?o}F|UDi&#h=vYL(njY?a|@ypV7mKYMHv>&LAI$jHzhK0|-|T>k=0J4T8NY zgtYLXv(T9l7m{-lHYh=!g}R`VH95PDm%ql1N0{dE;tdd6>}PA`xY+j=A0y#q>yPr= z9;RNpy-h}wXow<&`JO^=Y7Co`sIZZmqhzE$YCtsFTft9mNZgBkM-qB4+K5;MP>dG~ zvsOU=L(zJJrzariMcV?)-Z<{G3hxK7_L9bSZNG~hX~zm@mzoWRe2-~1Y%6=AI5~KJ zp}KSY{4*PqaW%~UH<&aA^@f>7<{0{f*6g)$lth8zz|+kj9HXeOPeVW zhEFB6!%!ES7(P4-cF89Hcp13|9L|&4?niEe-;BbK)CSM+ni(72x20e z1+99BZh=Kt{vQY>b56KyIlO*|#1=}3oQ7|;xiyWgrax7lTvvE-312s$>P#}9i`SaZ zx`f(Q<8+gPY~izBR&YC(P;~1TLcJ?-)^31zag)iX{PSD6@)w~xdUJ_*KeS*TLtvG_ zT9_s}RZT7>GvTz$vKG*z^kJ-f{IT(Z%lK)H!G#qAzY1CQ@Dfyc*ph$z`S7e18YZAp ziErucJRuYLegk8)kEzkUm;U6Qq-l+=W|`5XwQgETfI=i7B;yqiRI90!sK$gKeN31u z+lWhC*p%2UpHrz2yBwznIXxI}yUPecy9F%~Wxot_*T!yQcw?C-AlaH!7)(U(M4;{h zXkI{0VP4mtbl>vkvuG(%md+M-ff)MoxkW}cYHm~$QE>fm9Bpa1IfnCuzcLsq7vc{4 zAg7=;tr;k3W>n0E`;A`{jAUIr^`!B)s|&WGT; zbag$S%ke}M)PUY89$XJ6M@crV`Ycq>i zz2-~($<@CXzNs&cRE-seq!||`ba|UaHH(KQinte(pK?{{)<#`0gYRTQ2pRqs%pP0A zRkV!EmoqeoRbIXfr06T~IvHJd6Wj?9z1hS@K+LPEquG`Zg+v%sgU~W6W_w9VIer1I zHrF|Vzt#L?*w9y0ao=y0?z*jPvytT&#$TG%oUW3YJyN^YqZRp{CknHoi|LSURCv3y zVyvgXIiG*yf2vbo5n-a$(j60F^l~F#gKxnJ+A!XwLP}A?LmnLVl=RU~_Qu&^=%-~C zi}937_s2Ya?)XMh#mL%^F*h^~b6T9K3pa8icLGtV?e#s{CZ!4Vi8QpH@;uWoj;Rl} zSR71WryguuT5m_dJkF2g$F;GtTA#TX$ymoZJglWH{-UW0x|_NVcg{YoqwrluCo{ct$9@y#^BPC}gGP#EC%Z zC4Kpl>qlP1wLH6SDncdlU!pB|nL|h9VsdQ^!JLqC|mQ zKW6rx#+N81(sG6--*tl)dHzUCT&Qz$?y@usrHyb=6c@-hASWJt!#?x#a~wGDFZ2x*g~{)mlS(3g{M{b+~Kg*=Iy^ zP>6=3CSpF{<4U|HASp@tpy7WJa<$$Dr_5h8lZo 
z$Vt#!6bkjLJ3L337|^;Q0!CUx3M4nAYlOhy2sA2NI11l7Uyn=1Ys`=VgN<{OVI{Y^ z5cqqT=}AV@&5)9-)ttxCywgzFUyM%z!Y;i+I&i`x?K4z~yAwT26g6|HKw@i#E4E5k zTKho=r^R1a`=a$x6hLbEuQP6|KkZY-w!`o0yG&1LT(KO<-M8T*=XQUBhdgpJ z{#ftd{rZD8c7oJ@=(A4fH_x%4#Gv&;E3qN7a?~Rt2H!iZtQMGL7*V~NE3NW9=}%)( z-&m{eF7Z+1Oja|5s!IKND0+QcU=hsoa2kb=t&tz#X7yL}?iCRv|9F4t_Al}f)pNgC z4p`(SMggR`E*2pWXmCt4yuf)7vTj7Q`EBVT2-*IUK;_2S@VJUk;e4Rc|6AsR6>8Yb zf0$4LG&~FtHdkYt2^|W4`yH*;-aS?b6dsA|x^?}Jph++l)>{gs%@82BfA;#8_x@U% z4+WoPE)%NqvKA8!u@~0n`=gy24^DX}|%)UxzqWcaF0dsP^*fT*#aHEK|{Ue1=*71+Pw(S6!?$E#A zg*Po~Y86l@oHJ%<1HJz(G)YHp3X5}`)wt(Jg*BcyQ>9DJaQ&0Kb|CZ=fK)57A*>S< zEfPA@Wu ze+wHGjF2naD-XB5Um`V`H`@-KSepK)kr^;orR)D{EYgGyMuE{=^8bzb|1#!#mK7lg z9%_+-Vhv}46xcYs9-U)%XbLdR1>4&1aQf?IgkBW3)o=JRIIJ2~ ztG8bw#09NfJoxtjO}*4mn{qa#^$ZX=(qRS5|6iBD0A=(1v-QcMj&4yQ(khr>`pQAT zKbwh+I}gaES^RT1d`^JPzjKasNUwiBBvNr(vCxB#2ub)vm!Vo*(}R_1Gf<8$Gxp+2 z)K1bozPcLMgDKAzROdo-`Ml~L)nCF^Szj;j;jutm5&YKerT^38S24npl<4$|6M6(C zx{#KB8L3LEueYjEio4Lh!B6)hmu+V8w zn)wQnG4C|ZH8k;HWMNV-u}+%@i?CU7wSZT#vn<{nySgOwKG#V1Fm5~U;1}%(I)}y* z#(xREP3P`Rk`JaLc{BzhM_7NSo&kLy+!Lv;mBi3Z>Y@)DUHacITMlQ!dv27byUzhCyLI&h( z$+ur+a5ajfpwGIcu%Q&XjN&$`qAE_}4dbAz2F5@6NmEkLZUcD^W5!zEgZRiUCC7hz z)N)RDHzj_?r^};|;8hEuIC_(L(SW+`YlFflm|^@Lca0etvfTFri}Xun7)cZ8%H2OF z_Ej_|k3uB)9d}9=*cf1x_!ZMIy0qk9I1oDl}Fppdcor?->1_9=9GEIxwqhO z)FJMz{`Us4kqq+#&$*Y+8}b0kvZWKE<-&m}WY|vzv0?LJk9`(?*V>j5>zi z=08n5Na4k61Vk1(ecv9k>l8_9_DRxPZ#-IrtZoH7+^fwCPaX7c3(NFM6aNVldh`x_ zgF?~xA8`N0+<*9XAAK=bYhA!ilwA}w{y*M$+!w4t<3@j8riywX))I$Lh0hdCD+h~- zX+K}F1U}b1bw&B}sMFJkng((9eb@@3piqWIUA)-8%WFWm23+D$Or-QAmB)l#4PXU_ z2+nq?f4NN|x(ZQx*-ei3f;(BmT_kP5$&G7-H~2Sc;fJql@9%pL8rAWH^d-pISbL~WTG0AUN z+wpgIcMNH>(Mx@;t-u9C?;qxw78C&Fat!q!S25%xikS3-0` z?xrN?uUT|Uzr0LP)Y3L@50)un<7+Xxxw#>rwHLp$#Ajeb`FwS5EYmkNsY(*(N643? 
zactevoU})ABx+wBTT()ys`?E@bPN+E395Achwz5v&IQe>v4E2kUXqVU(fE;DG+wui zmWR%#CH(nmHsJkj8#8JCa0vTsMHk_#zyB5Z{2`WpY$`Z#|FXRt{@?Td+Y}p!yiccB zlGm2q{U&-_+ls54XOflC!;{Q^VK!3+s=t37&W_f9e884npnxTi=oO#6f0l9EYg@}w z*1q1Zf0=lhCi8fyN9 z20ikI{&oztFaB7HBCo#XzW#4;Or~8>y7k%8=vP+nf70fEw>TS}P5z|(9{?Od{7>Nw0SR7`TAJ zPk;O@jy>UA3?vfLfw)Ndk%GRn?TcU78|R#JvP7mVeeHGZweL63(b0w7Ke>y$8TFis zygw3&;JkBA!rq_R9SS=we+%ij{Nn3y_4Rjv5y8C9Klc<|a`Bmn@<2J#0Y6-frLVt= zlTN(=fB55HuEVT(^MGs+Ju~Vc3IKyipI@e9fsEX5TGxcLPd^X8f9dzbB3%4$nmu3>RH)WZz!;HjX~>I4pZ_nf&+7x8H=X9rRx^5AoTS z<`%s0{4en0FMowMUw;$HL{jFo&+Pqa9CN}k_|&I%m;3Com`&5M@q#)B7Jy3-48caR zI>I3`6hd!DCr&y2BE0a+mxdfhbB#A^W+NWD?V`^Gpa{vb6BHVWK9XRux%f@t|op<(0xbDU)WG4Rg z&tJfo4*2${HYCXT%bz`tJwLUZdylHZNTqP((P!YPr=OQUEnF}cPd#xDcG>OYK)xia zU$P!TM@uVCIQe^c_NTv;zkdF+d*ZQ&Z$snkS%7U}`FqRpVYKU8haG}@ zAG}L`&W-M;L%*k|N6tb&c|Z@M{_F<{`QIZ-S7V_&wcEpJK>q9 z?!$ZrM|(cgVFhYBRx~W`5NxL?P`=)SIe6J#_LP)^p9V_&tCWq z-dytbU=cMnRXF_6Z{WD&kHn&F7X|FX@j6cc7`TAJ^UwYSM;&t}KJoDnH&BhNNxdRn7Rm1M*y?jP4Z_r(F=?%F4#=EfFt{=s+ z6{`f{y6>**vBO6`;?j3>FFvCe;i)Hnh*QtF81v`PMrUWYL=!%8|IPT!=RX4{lfkn; z`2`+$b*UUgRO`%y$#f;{xpe zHys@Qv;5tcMDXj(d-+VhBSgMKcgynW#3H5OQJZ+b|E~NFJ1Dy_3Jof^PX4q0`*LZ^ zW4$A{{y`3d@urL|{&xZfYuBvB@!vfOzyIwYuwdbQ)YjGxiHhN7cgD;axb(`4ue4?|zNUu;+$Ey8bq|9hN#+^JahzKe{1 z_#->ZCYg~qBY_4in_%Ph*W8GkZ@hJoVQ}%~7viK-PC!{j89LfJaN;p1<4=G5v(%aO z$GDLVnEl30aj&r=E)~;=DC+k6d@gryM5xLxc9E>v3SQFVP=xJ{<_<7(PcN_(|hhF z^LTwj9abz`fxrLrRbh-1V0irD+t9#yb)-3ZGdKCvd3^vFOr5sayfukOi}AxJpTe1^ zo{e}sj{OhVPXeHPP=Na??yq+F_($>O1NI+eAh6cHwDJqjzkp+pIuVtXm1u5i#^?6k z8}~eL7aAHGh5!baUT_)ix$Axj!074e#R(@Lhl?(|Kv>pf3FKd;(Netln-_7$sb`_N zxlQDx=g*&qJMOs+yMOAF^33Cp{s5<+cqTsfiCxgpSdaM&=3~vO)%f?nUd8fd%kkll z?1UdY{W#{$n>(cLo7d{W+XD+sWRK8r001BWNkl{2Bn^x^({ZxpA#%J+@LG6=_U_w9G#vMXOr*{81=FOjjii&c)`}TYI`zx=awY41wfAc`xb@vUZ zEGsRD{XIEWpN6FbfWdme;5$d3fjvL9D}MOcohYfQ1)RL-MBnM00~jpcb|Id7`a#62 zYKGsR>paf^46eTRHk^9$cd>l=Dm?%5-{7X}FTv?&p8}Z1(9Rl;pk?hkoN?wQc=_+I z;E@0RFTC)Jmyk@Q@YsX5;?sNWA@`F4m@y2IM{)I%xgFQh)zXB0_df*x`uFQN@wg*! 
z<>lw1vaWVqy{>dRBa0&C4HV{2@F<3ZfQd)*F*6x)X!z7aBKZQS(8-?z8gSFVGPz-h z7RB5;r_zIxD3nqm$T4hl?v=x5eMUwut%>htod%^Eyn_bIVV&vF@i!`*%)*h#*av`s zVK5D6u*t77`^vht>u}srC*swAzluBWy#srEdUxr+xf@_^beW@s|C<}x5{t(0)1N$# zGfz7QJMFw95{U#}d-ZkPe)p|7{9A_(0StckhdNxTH0{X*A78PTRW<2tMSI`Z{m!zPseFzoQ#Uf3cUUHJ2?OB3-OVUeFTRdb}-^4 zap)l(Z@jhy+-#HiaQJr)!*w@YjneYcadRf{yZMd$cFqgAc@g|E<_17EJ*mvK(T)@EV59}Y5_>u9=jjA5PamSpFC!YKX?!5JCeD$EO!O5lu6}*L27!JE@ zrVmtyi}0_%{S8Nc_jGjk^kVVig?Qza*Tm^{!f{8qfO&rZ({&8=^kV-5567Cd>rhu$ zBhnwIpK=T?{NAaES5{-qvSm2^oJ+BI(E=QI%wdR^l>nNKcbB}4^Dn+uq(8p6&tACW zwrjBPLyVdkeuy4oe;;7?_cvVM{oF0NIT`pjFJIE&^ZqmW4eFQS?}2j5^aBS*i{8z> zcchMzpC`igI;^j=JBbvYw*qsF6NbSFr(S@kpZ*!X zdC*tzKQH|ebLPy%qYvDIMcZ#bJiv-Fj2Rnep7A|A_x!K$iCsR7SN{DP=FguaOtEcu z*a7Lj9z6H_i&(R!31^&sEW*WAUdtGG`2I(6%n9en8Ro9ruE7z99ptjZ$J!$dY9}jh zegz4HqSmuFsTk??L01~(Pke23zLt4uGQ*u9k}HR0F!w^n%Vk5kNuSp?MxF(pqlV;V z$nSV*Ad?=N145qhHQ{CQbaOe3#^M+(G5`!V)v%s!<@$iZiN~IT&+q#=JpB0msEk*_ z%$pOkwmHnHPf6ERt5)O0V@{I5siTiO3b)*N8$PqRh-mFM}*F?IczWu66sqoXk_Vp>$*KhiZ)js_`z z$TJQyti53bXQt^kI18w6p@6{^SKopwFFOb4Uv!qYX5_8;w&4Q#;SxN2{{uMjl=E@! zStsI(Yp(M42R;L6gf%}L62O40!oU3GANcy$_Q!KSeGz-?{&75T&yA>~ta|=gMpeB7hZ#3zx2Pj?dB^4>u)ZVKl=%g{HTO4@ArA! zeaE$fjtn|{{_>Z<k_#BT z{_5*^{He#V_vbz{m^;9?0Q-EtbDA41VJ94Y65f1c39i2W3jFS+--|pDqh&?{437QI z38<~B#TUM`53af5I_$pZC-L-8o|K#i`ANC>zGt5LF}`>91=#N^UlJ*jU;pwqIQ{I? 
z1TZKqD;0JF^CHBH<0zpsQI47l!U*BQi!Z^oS6z?K?ekgObo-6iW}8JLM&;zTGV4HQ z0yp2>+!Ng3fBT7-eL#VA6HR#UfbA!A=4LlP4+P7g_A$Xij!$lOscp>5e~+K5!LN(2 zbLzj}-yG)x6ug&Z`t)UA!imoq8DnrmO#p@8lK+k^zjtpPUsxz$z|G~;KfD+F?!T|+ zkjY74&h@~LVhx1_F!^4!mUwH`32D2JI!p|`IeMX?xSWo2H2 zawiwU4cA_e^DekrG(7LR?P~0`&*z3a*7CXl$E#svq3as-a2PhnamH}RH2Ft z!N2e?z0DT)>JonyM+i7v+Fu6gNIEw0E5p-zLk8y#FJyPI9849kX$daZT8Bi zLvJvVyT^32vstN|e-Gyn0yz1>3nTXa05I4(CWS5WD>oxxaN3!d;HV=H7WQ(Kc~kNu z{yARd+*33N7%-`=gbpEk81SJDBM%L@#a!6Vv_SZ_hYZU{+38o&Yw3&80Z8{kWFIAlD7q@ z;2gz0CfVhS7S4l4wnWZ>zO|(dZ@=@tus7yyyA4{_tiikQE|cg<=EnQb;zg(-`yuzd zCw7jYv$csI8Mt57QoR!GP;JpONw^=kFbu$|P zE6>J5#!vtlY>89KhJ2F`7@Tz6sUmwy*&(v3b6J?VYqn_i@&F zr(>60cfkoqorI5k?874h2A7)-e)9)0L>$z4FeVAoH6Le4X+2hSmNi2Hj8o)4Oun()>AzahB|Zn*h6 z9CqZPpxo?e`QQBfcT3-TTh3q~{n$t3>_uRMa|7AlWUR-F;xJ4DZ!diZt5&X(GZW8A zb@g>vv}hqHkLI%_SY}sO7v6evsYo#Gu;cbmUF453auA zJluZgLs9{m^>ZBJp*a5XzrPo$M4n@=zxqOa_vmkr4Hzhj1S7xt`Vu)`eDtF`p|i6~ zq+>aj7(gzJ?@%$URIfXCAKU3+B%Ocs9vo27{bduU?B~%U5CXHVe=&XEs(Z zTaNdat&}lIsoCwf-`4d>=YsgMeuvP!dJW!xe}&|;6JLz3$vRGd3ZMV*hj&0jL;a97 zzisVWytVXQBvWaT$EBZ_zU<06D(j4(661H*05BK|7%(5fBM;s(Op?fVJc-?M00VCL zo_X>kiiv!`i$DJjc!6zQM1E1e-f531zdPJ=< zi4-on^g7&p+x@aB|Lop-;J?0clpMq#duTnwU_|G)t|K)tfa6a*4^KSx9FG0&p}6|W z?~Tnc@Bjlk_bGQks_{GTEyLUIERzlJ$3DEH*yca4F2R}STq?lDyg9Qld(KSRq*8W) zy3ZG1aJn2MBvA=~H(q@c`yF_gpm?)qHcBG1mX|;BlxG08yl0wwhfvTzs?D~n1;tTuiC2S2IK&(s#PyXnqIOqJU zgwa6ohZDiF_gCV*Wh*de_DnfA?7sWQM~?0c0E4NuifyLa!E+wR7>7oLOTPB+reHeHwJ)|NO<@;+EU)$EsCprIjix%kkd_9f0p1c?fD6>s<#HSy=y{*>~yX-^VXr z{4K7%`g{o_xZ@`%EG#mmoOmDgE0sQQRUt=JV90VF=F49!l zJU{u^o!EJo54(0yRrGXqp9O%yB#pYD zm=D|nn-MS=$}nJH0s%oU3f^ZI;Wwz&1yg7Rmt($S`2^-*~j^ ze)((MdG{ldWI3eklKh!c4V-Vk@?ZO*grEUUr>c13@u!99aMTe8tl2ZyuOFchRxMkB8*aJ>kN@Du5;VdN?23Yw z%U9y2Tkpez4?QV$W)vnthJ(KLWgLI(VfYktN{lMY6aWSv_;{b9FC3X+Pz4iNahN-B zkk2mQo}SL>zy9A}ao~Pmmp%K_KY0SLz4itU`|l&L%O^fIDqwKu?f2m7>#vmOFS+nC zkrul7_U~iC{P{?wT^UA}AJrpx=;23j#wlmxTZbKtFCVZUPCelal$4eX0Sx^6bI)@; 
z2X=OJ;_!oy#Q**IuejyTn+2$(-(obtfcq3O7WmBTuf2hr?zj6UNy$=bqfoD6OK^87rfD=zS0S6uYb&*)&m}Hs1d+B#L?Elw;?2Y43 zJ|53J`3xR<;E}=J@Lyj)5Eou@zBvB%l?!4y_378^?A1f^#5>(6( z!*kEM40qi1FmAu;3MiV2GtRzL^0vARga60gSAa#8wqf7YkYaat_uAdAt{o`YiiHIN z79xl(Qi`C0AQob;x~s0;-QC@yf*{Sm_cJp%z<}=efB(PVx}59UYj{o-uHQ* zyGkfY*jm?oQ(QTJ4weJQ;KlQoSTNTK!>s#f3>c`$N%wBwL%WWa$X6h*fDCIxw_^E< z)iTn~Q*AMI@>o20_6j?;?UFVB>9ZFp?i9_MHh}$v5op=AB}^1l5-Q2j3+DU4$#o&7 zIE+E77L74$t~dVJeVth%9bFx) z4PJuo7TuL=K7bq7Z(;6&C0HM}4f?uzLPQ`*pr2(Aj2LQ#8r)lDWD%bQ7$^aQ?VGnq zP!=2GY(&#qKy6z=Iv_oY_-1 zAkqK^MCfL?PQpZ6CjkUlIL?|m1%}3kDgMXTE?$s_C+c%FHifN_!mV4cq2C4!_*r^- zcyQ;w_=aA)aT}`wz0uFI`!@?w5FJHYz=o&wC>(x#E=El|Q7%|!bmo8n0r@ITfbngyby+qu-ehbrP%ooss z0LZZ6Lt(1m1Z9kqM~|Xwk0HpB!yJLj7ol>Mit?aFM}EfouX1x3cuw`WF8nYz#0el zAIcOkC{(x*{8lW+w#~m{j@vwpwHt$(vu4OZtBXed_t7KAFvxNUDpszDr7IT0d(jdE zE?)`9sSYWCfu?w2T(EN(@9x100V}4 zk{P>u&p|j&oh{N>PqJ>v+*cQXIyyLg>KuAlSfgms!qOP`TOI;;Pk%5#GzBooDEv!d zhAm-X0vIsN%Wug%v~1h@n}ETd-}i%{25XmGxpHFfzQeMijfng#>OCq8$JJ zHt{iT+@OvSroumb#K!f@P_J=gATC<<845go@Cf||jlr!u_pxJ35I#kGmbtWA)rzSB z18uj6!Vvdw-A0c-!-cIkf6g?qaTU^&Hs55v2gwMp-=cf3p?LLHNn(2cE?g{-B0oo? 
zZR@769^4nxoaajbGXos!)~$tHIkV&J*^4-Q_{lS+vSXnuvgOVx=f}q8$T5>8)Uj#fdT83L5xzu4$-GFG zBdr?RwQY``-8-R1?P}sNeedQ?nb!{=JtgPbpg|pciHXI<3zu>5&&2zBcHlqq1~I(r6s_w5Pmfg^G8!bL0#@<-3E-4GS0BpN6T6nORe6(-o& z;g7xhv1;8)%JL$tOg9M z2aUj`^OsP#Xkom0^A<&mG4m=OHEPxn+oKiB0^#fJhnlsj$S1I&cU>2M$Mp0tFBe87azw!bJ*6qwn#fr+D=6ktohRyceRlYM&Jq9StYP z=?Dp4gPOIgqgBh6NJKJ@96Evv=P$~BtbK>J=-j0f>Nlzn1APNLdh{66r#fTnrtN6l zq#>HOY=*?dBwV|G9fuDb5rAQN;8HYi)jXYQJnLEb=Z?Vl0tR&M7&gQbHlqf?fEM2x zYkt0bc_fQSx#pW8Xhi@<+DyXWL4Dvn!$DNW1g`Vu$pt-Qb0yfIX)&Zy7%*Gz)T#5> zu+ATO^W;XyF8$?ww{2R9G8M{yB?%Z9;O3PZ=x#AgdL%ox2IK7cD`EjlvU4iHKS-vkV}=U`tIx=O&41WyvjKYommEQGF|+QNF2qmY|_s$Cs* z>ePghu^~>LJd2~p&Vb`Y0(hrxot5@Pe=__%J^bK0+Y{x>mcq+dZ(v}c5B7xgbb!0} z9^%cLcd(l<0(0CZLr*LOfyWOYW3aUyPM$m`mb!K8)WZ9B@34FCAt7+nUZ3DY-#*<@ zrhI8YM-LazoD-WU{=UP{Em5mxHGGVS#Fkj%$Z`=+OML}yUAcxf9s1ze^Os8Beifo^W@d^U+0D_ZV;kA1<}Xr6!`46_ z3=Sbnh73Cwx2f2&VJk+AnS`>XOJarpLR79%6G(^wboJov>4%AvW=i4e>AeV-&YeP^ zextq(7?5or9}jmAKg^iv4sOnAyHLCglg2eHkJ4VKN?|b0Zo1eF(L2n0(M;5?-%w?} zkR;6q-n@K33y|R=Od+;sjB)wgaWrYs6&!lO?hX^-kvD~YP1aEbED2_`un&Q5d0a>#Xt zJ)B*fAfPcfi`!p@M+3h_Vb<#LYv>gE_=lSsTS@CVaKL_Y zT|2-$M|OEmZd|>I(c`D$;ln2a(DxfeId43Y;^HvT!BuJqKHpU6ouv5Rld+SHb!(r1bls0Vw%eWDfT@U&Oo-DIpya;t2SVvhaXxtYXr|lt|(cy zH1x<>OiIGA7Z{Wd001BWNkl}QCqm-da_EqVa) zDnIeGw|5p`@XvqX|E*W>HvtA@pOVQ>89c#BZgABED{i94jv9;!6GyAmP|PsciT0gn z=V6G(jp`vWAwg_GIL|a|Twj{w`KeM;Gut!(gZ1mTV#P9dVZiktG*)tCXzA)`KN<>k z__4wmYlGKf%mgRd2ru{Zz=4B@K|8o_0tVdEB`0IUhHbDKV1tq+i(&hg;IDIYMm2%1 zF5bO)hhBY$i?V?GKF&{r23rEUR97eA{ma)v2rH03uQXY7jmXLb68}^}&2XKM+4FqR zwR3v`K*Zt!z~->+vd`lhM#Vb;g(fYU%KQEC+dlNN93lIf!~3?RD6K+5!Z6m(3AJlh z6JWYf@iM~bynOCF1`ZvMw{PEz624LMmXfs)m#B+rQ(fS*H~^C!$G~-#lkAgZ&y*O2 z1q*xx=&-ctiWSQji4{_MC3_ZN@Nb=V(q6Y3FmSY+EFb{0)eH<2DRTmeLl&03;WT4P zicK*00-tl0YSz2CookHFe&wlvyDFRiiRgsYAj-5J4lmEo=lajMQYoqPI zZ9$v%t)&*+v27>DkFi7UytxsydIjoKtAlvNuOkpY{Pi8Ny3|s}J4U{p=%7FVZW`OSj4BT86 zV&1|f0wCqilN0)?>_*-L?DHH}vLSH!a;eX`F4n78OZp?cXGp%LeNgcd zMU_r2>ivlZV8AB$#?a+L66n}<0QMg|h5&ze3?66!T{g@KROso$$J0;NDeIwDm_OGU 
zYu1Hf#OO)i1`PD{k(ihOzokJ!kR};vQ&@n2a2aM8s1ybd?miF#^!1x}r1qg}2~~n@ z$RjxsFCIU|0uMg{3;J30ROQM7Um`!lX6z){uh@(n1aF@?%E0&JWGq=6Am>2>0ZGkV z^Q+aaBP#ZFYd47#4M|jo_H9OylEslgvL^RS&hyZzV_U3Tz6fTy*i((ik^M(N(1oP; zEnxwuP_?>bEkws9h^_ACt-B;Qchp!Ll_FFZi7}txFxeGeJ}bnEarWHl(3hAA4U1lB za;rooN_VsW?!nK>kprM2Jk6n#oVGI)9Swwsrw*?BZ=DXoEcWy>>IErGv2D|4uz^si zQaQ|=;eZ^ul*|qFLG}p!r{qvI2k+i{favHLVd5uK5}S~K<0sBalb$jUi(Xx@VBU1( zFGSFQK3b^(gXvCoh>89ph#2Mi0n6NBHLyRBoPgKQp224PRAG8B@T0rM0GvCiWCCVt zNu;twI(_0SdRPp@>o;%l>*ipz|D`Q-)M0}ejjIYU;O8z6T#JSCremVRcxeKpV%Z47 zMP{T%n4+E$p4`793J3xUY|7WG|C5Z>wM*CVOVMS zgc96Yu8uIzQv_SW){Ek!Th|T(S{TwJmEVV+9_Id;hE62|3*EW%$->shlY4hCdfZe| zKJ@6;5#`F2R;IFxC%V2gTUV_P6ErYv{W7#~*D}2YMHXQ2-+JQy(eI-M491KYk3GNd z)p{-Tp$uNTQX1#9<6%>Kg2 z+c$6F#PL&NGcv(`oHWl_z{boFp)T^+g#C5jfq-3RwXR($&887x`ugWf%R%IEMndhDna(`43!ZeAvmS4T4 zH2?#?j(`ChYXgUj$I;_wWPN35sHhA8un^_?Nh_Z1n^vN7jmj85ZmP&yzo{_bXPrEW z5o0EcqR?xh3r5V*&Z6PRByoA&q zM-QJs-vKt_QBI4-;$`?5F{;1Qg_rvx*gCjKaqhmr1x6-D+S$TcfWg0ThWJ4)`ELdc z_~_7TR088wjdeB$X$LlH%rK>Au)Yya9Xl=?;^=6VDodrDd>{Tnn%R9Vx{7kBOu5p^ zeG07$zkC5R3%#&bMvRnlT4vzU8C2H(#JN6(rz!QD@569-P zU<@8+`)$BLVJt!NLq<%*!9&LdeB^wVsk(r-j|wp8*~c2kPMj4hR{M#gVai+?wO6w0 zTI6m^v8&~LpAB0#hM>3Q2sCL_7qrC_aF-c2VcP{1Ake}2v1FN2%8k38K0dvBC;QZW z`;Q=ar8l~E`B~n2rrRRSclVX^wjFB?eJZLF6A=>?4a)&GqHJ5gdMUc}>@L8=)$ck28oT7ybV&jGtO2uY6r9~EC@Po`E|JRqH2^f&s zUZ_w(X@;ri1kMlLdUO?~F%J@$u3f!`{(S~Z9l^a+UmgK5c&Kw2 z)%!Y*lb)^~6fnWQ!@JR`eK(XUT@q_TR|#>92LXfL-TKOTuiv;<0-@DvE)uXDY$svO zs&#PpS|B+K#>U3*_3_1wX)`g{dJv{KO-4@UUnL~S9>;c!9fDV`hNq7^#*7(_L?q$X zty`EdW}@sfnb**|O)JDQAYGMVaQ*rX*jSGdLdLp{t0hYzZ8qKi^^KJE27k*4{1w1} zs+K&tbER7IxxZnfnIu2ue#rzowr+!y%Y5lq@=$|v><@nY6d`*b5~B(iE391K4D`j4 zh|O*S2JF)jTv*~8gelH*!MxJNUbB#=U_n`53GVgmZ7unxR3o%*-wpo0p5j=Nslq^I zlX&^uSqvL#FZn%G6ppp!ekqgCUk$*3AjX3Sk0s}I=nyLy7_0UcYWqGF8B45?bh+M> zyle(v?}Z>Cn`9|}-}%U%C%3GtYQTWxv$Y|<$jhu*S@Wr!yf6Eu`wt#TUT(u?4G|L& zDavXRuQ(RO=QhKy8j3U4Yw65SsxEKlC-s_d zid{Q)in5XQNJx-RI$M6_mzV_@h$RsL12*pIF~u9z(2Qt}W&$|nGaBrC37jAbA;yy_ 
z4|&3uPZH8a)c_IXqX)L2USl>L;sAZUR0;#C3=B+i!N*l=WXcr2QUe1*y&bafDh zjj|UzfYqy(z%)lboI80ebHIQ%r@lTO-F+Yo725WYg~LMJh^S3QOvA6MorE5b9ydiS ziFiW~A2kAqk4nuhQH5=M{tzy5R%RC@CMDzTo410V5p5)3P=``h6=2Y*+d!D-Fvqry zD^aps>8~MQ`ubSmA1r`@RiAFEoQ0w|eeAG2%xq%NGrB?3RsiDg?!_}{EQD>^0r&aN zV8OuS^0{+@W}drnMHYWH1T_y@A#lMa>b7-bAllMDIc?}-7GUtVELQ((AF2inMh>;X zxijbB~JMNi*N+IX;NZCH;C-FUot^F{!J1N#obKVT^a z^zDa;*a#`KOpQ#i|KLF|!+;8dwP7JBT(l6@tXeBfjxOChW9hOb!ct*TxPDD2$l&SK zrxzA@%$1^B8(_eKkN_E3B-^*{NXZT0I8)j;&UP$@j~XiM6?L$dw(md<80^~dn-uUr z)%^*p*9VKNSY7<`?^y5}8WOv|4slcvGT-3L>gCJ9hsYHBKg!NhU)SiNeU zT*t)mwunlImiII>HpAKT=P+Ujl~f5>zS19!>NUX4d$&Nj^P`84v3}DU*<7d#=1UhZ zp;y;F$dfk@)^A)Rst;}!mUu74T(@~*-_)v2E2Zs_Dp;8X&Vl_0WphFRgz{$QA7o`2 z{I5NY|JnI`4`488m@N(*JOdL0@dkTHwB_ZDT1_L z33SR!%K{Ak{t4iJ;|$V<9xq+KSgYMBpH%{i^!fMkS|YYQ+ye|6G*FoANmSANu@6?g z`b!{j(Aq#0%vBJPv61qe^I+&P6$Ufhri-O0=N9f4HgDW2nQFtV2SZm+A7{^;k^H#m zm@n`PSb`P}TVdDU9iUy1rX`VDAx=A>Np@2t|APnfkQg>xs8At{9Bqvr7Cn$BPwq^c zq46z=J>8{gY zZeor^BqAz03S&l$#kw^cL`7g@qf|2;K5`hgRo%w8W4DtTuE7**mCL{rpsk-;-jM%7XQFsQ2Y!KVnJ_HQJ?FD%K^atHXN^>ZgrW(*jpz=DT&A7F^Jtw>}UFh;K|X#ntAzGb(eQW&TK12RuW zjUno$Q2;B9@btk0k=Y(Pd;%|Cz82O#QS+oE<)(n)!ju!G1q{s0%>*zIn+G-|ROh_h zF9`ON-6Y_PxdaR|i;Ini!^COu@eLFlSNb47Z)!aS zE}T9sWAo`#q$m$+)T}BGTnfWMRj82RN%a9`Qt6H?z~FoD@*nzjni&T4iF#3jTdAV?X0{QUWISai3-ojZ51Ve>jPYu;1}TU!2T zQyB35Z{NCu@b}?TXtSt1d-fcH14Cp1Yu&CD946bNYSk);)dWH{zFAGcfR;qfS~f%Q z>Og6{YclcGfy7+IuUx!>vuDqXg%HPO=dU|)^7v^uyE=)&z{JEv00X_J$#mW zV&sV7DFwBOkqORUxPZ~a$KZ41XDkc!LzBjh@$BVuxK4LN*!qo_Jbe-zCr=cQQ8J0v zg{+rCw{E>U2wD?}f&~l6hGg1gX9@e_x$;_4ul+hHV97|E;i7=Z3P1jF9_#M`49G}5 zdh9edhXtTToAxU3UoF98I9H5v-Bbr=DcCfp)r2Hfss2g8tt*#d>*y*hPG%DF4 z6+%|}pi%Q?h#gQEM1yiTZfwym< z5HmQ}@HEvWj;^kb{G7+|p;qv8cZCV1|JrK0EWqH$dGP;jSMoOl25K-)VU|rf=cx}q zh?<*k)n+^8_iE5FApu*r{SLD9^X1KrBm1|YY(*ujHZD3kMPVR-fw1TlaCM$5jf=qp zEJVGOGhZ=`vl*uZDCaoAVTzs1)0;MgW)2t_7$80_4l7rz!N{?bB_D-q?iMYZWNb^I zd4FjI3|geNK4X|^OblMV3l|GFW(v^4O;=A3PoF-Q`JOCFt^;a?ff_LA*5zl}i&B{_ zV@Y5*B2oZ@ppXp$7;tW*MFVqatVh`+ySceE5Gq%zhNQ$e9N2dlb`CDcp4}X)gBGJ! 
zz4}tUu3EVY_LF8JDk@soqHO4>_e!eRYPnK+8hd%nL?NmlvH*jB+wAh)*Q*H_(D$Cf zzO*aVBo?VH%ZwC;IDGghx^?P_!bJ*8W1Nk6_7M0Q+Q8b4vq$wB)v(CN9R&*Hm)cxy z>7xl4j2S-~*-gyh>Ftdv4o+y(t~C~UFGPh>mEq#%EO~9C#*P$ke}#?$zwh0h4lq!k zBR^XW9`M3nzIZ7lErO#|Q4bqA6wWTw#PTR3Vv#n$z|yi0qGP{WzY;7UQHy&sTF_Dn zedqQ)*{8jI`%Xfj*^Cca7lg9q%gT441`IeRA?tz#FyJ|`?%}@)7-H*4G!@kqp}6Q{(+iF>4FfxhV2u7m7Xw{F`e3WK-r-XTw(JVHp*roJs% zvIN|{7oc4EvRQz^cRwNj+b{epfC0mYnZc@SXq*^vYHWmSa>oCH` z5ml>Hz=1s*k%wWvX#j(DYnGz@&#jRhA1@#P^W13jPGXW#HgojsK0p-W4E60hz)}X- z-*;K&fPsMlzC=aCW6@GLxhz1P+STE|bOCDBp%wf$dZZZw231p&$@TRS8y$_;Z{OqM zg)3sGL;#-8{oOxR>(2Mr!vf;+)oLXZFpcMI+kTo&}pKIiOv&-uQ+Z;$@bKYH|^i~#k{npHKes-8J( z{f$TxnpCUTDIg^<(=QQ2gK|9UG zk>7m#P#H4=%~H7KjZg)FUG}Gx^VzR`JF`(n(gg!f*9+xUIj2A>nWD)Blo4R2Uq@9W z0ytAvGg@FlMny#S>}2|5KVR2FPkAWxTvTJR-+OGCzaNO=8no*kRVi3qFbQ{;kqG5kMqf;3XN%kMShkvvMu`q?R2rXbR_ zs5)e&3xNR*0R+a_%GWaqiD%M6t|ZhFojaDpdMG7tsHOlniz)K0iYJzkxJW4Bx^AdL;yG#ZGQ@ zcS0*LdvEJZ*79y{FUUMpG6S25Eu@ys?~9!M1Ge9}MN)`x=4aDF?2qm};4zBgfG4;L zzI?2~ay_0y|Fw+p1FOKR@?-9ayZs})??Mw9rKq>~YfJX|dRvy2@QqiBDM7K>Q#(_I zbj{BD%!72plE&XUZ$0VXeq;Yk$d)88347et`gNP~C`WzL=HK-tuP~^$nC1MsMOi6@ z!~8_25eNUdVK8AcV?Wu}`MXI?OW^8{>JLmv(an%vOKROVq`!-=1O@n{eVCE6)&h&RO=MU7^&k^BW=ZXQ)V5h2bMTP)F zye8kYrntW4EH7f^JQb%wh$Dxyild}PiHx3!;#Y3o=MeGk?(1;V4&E_>pN&3;sc_&* zqc+^mY?BG_Yvt*#aO%F^%7joEs4~5A4l{+A=O^2aBV6lVOzZoxYA2zceKINmQDsmZ z5gVo?^ysIk`t!_p2(>#i$)Qgd@>`el3=-M91<}A3Cg_4eeA-$P6IG0$DFMq z<=v|-a`GmTa&E`+mn2u1pk-dWSHq``MMabT@GyVE5Pr{-A{|VcSF|YWRXWu#cgB5i zj>Oy}M|l*+a#^}oUfIm!Bym8*{Vu*qBQj&-CcWgc-WthdWCXwOO8D8Ly?~StkHPQ8 zYo`<(z{nIlT^BwuTkj#1Zn$)MIV|)wuZj1~mZoB)|4)`QzHu~SdzKs?adZEG3mkf^w95^@VzFm;`}@Olz2pR8 z{9K5}MUpX-eied69GxNL8cy>zij>Ff$V8mmX_35OzH4MzYq?L_Byw249;GSjndNdj zMUg?%TS17lO@gz_k1egc^pn7q9kHEF7;x5yhTtIgCXfC4C`T>j_u>G)<$Mhp8sSh( zd75^v)s%eAR)^n(U~k-OI0U`W{k?a6mn$+7Iz{&IHkjx*`D@UOJ5&ZIryWO){1q|0 zE}MJ|mT-cQ*I`jztsbKGu7+Wb?-fkC7I$|R+f28#q>bqi1&EY_;Rg&8KQxrk8t(0_ zDHxD{0?o*?C45ufnkk;SFM_J&mTmdUkU`i-%;{U(EG)2yo%gAJgonBv4FCyI)6P!ufD=AD4(F8do_I)C>zuk;u8qL^ 
z5xB1}NZ`p+-Fh3~5P}(^pY!{R!sG%)%@L(YxpTyxeN7GtbN3_4Xu;|Zlv4>RzRrRJ zJDDp`_jaiX$9|<7EAY{8^?gtzjh`F-o+BB|DUTE;NTp;FW`RxF<+_G(FyBV$ zv{(R~uyDwe(YySvH0Q3S>PLAq1WFIU>G&@!07kpl$m18|Yi;;IGQN+wD%ecIt&b-Q zJ-G-bm(a|0M&;WN@PT+V5$NLeZpWhMN9N)3iAA1eq{L*prSA{#2Yu%Rn^zrj(f8M8Gk!w>M|O{> z;Lbi73IB5H0i|LM-@@-` z!toGQPF#F^4wKyPh<><&J*(q z=B~*;=Dvrb;z`_Wy>RI<-7gwZ~dpmNWUpQ#y^J`nE$^soI-GuUR-Y zy#PF(!^kC2Ct_G~3WO>(v@bSM?ln6hQaiGZifIgfrS)~cJ)w`$Q6?ds7R8$b33Ty3 z;Q)4eYyyCbpmUfR=2*ZX-wm9c8oczvG5^@e-}J`5Cb=>cry9fgybc83`zZtszC#^e zM9M@s+h}$@h(5WZd_`Sud&ldzJP-2gKq6)jyBYm0Rb$uU0HZCF!|I(y9Rzn&ccW~x zrT2xs$fHqVp;pKRP*}CE2ET{^fUoPp_DFH`fG{9P`sFcE<5F}1ao(9V64fq)kl?hL zhQfi7A_(x9@&)UBC2ev;F!>Z{4BuxO$-eqtw*5Ox0QBh6DV{OYbIVsgKH5oO5lt!e zU26%6eYTy4ShGZ95{tOj_JvqX_&qbPklB9xd%yh=pKE5+1aHn{lFNQ&Ug2|20F{7= zlT{Rvy%|6(;#-0f_H4vWwvhFCc2CG{L%(K?3WP#})<582fItVol_FZ*{EZ6mn8lm% z>=|wP>^!^3VG}YL!cJkZZIUz~rRRH*Y*~5i@8`+ObQ?&3SQc~`wA|`}u2(ss4oWQ%j^w=_yHLM)sDf)pA%X0<*`(rd!~OH z-`!ao2n?;F*o5Px2zTUq{beZQsLcz(AuGRsqJ({PO6o@_cFOM>IHL%J(R?Q%<}*Y8 zdzryfUg&v-2C8!08fSyT=tTtpFZI@V6O@oVA1@4ss5#wN~)E^h0NpKXDA_ zS^9W{J$P(vr^Kez?*_RW2>e(#ork)U7z1=P*~2LC@bRi_bi)vHt4}4SErS_(xE<-8 zaZew3SdE*a1?#Lk`+x9@M9`zNUt3hTK5l-00QO1|h0v@#`3kP%(q5Yc z&LE{{_$j_z^oAx@`wUU^_(Fm1X_I?nFh&ycO`dLG!K{}h7@2h5Q-~n6 z`rY8cz%jk2_!eEVvvw`mVGVp|@g*}wC5y2?`+zhQ`XOArREh zTy$eicT@??x{5Z$DuP0|CgiuK`0`0CG(BIYpJE|UqUS9NmJZ8ZS-@Zh?lTzx;mMPsDw;xGlN-px}#&#(m2J`j=dsy=B+^lG*$JEzN-mn5otL+}JgtFaX$rO8bsYzv2whzYzuX^w zKt`Cz7mTX32j_M~=bkClaNNBL-}58a;_4xzNbP$FPCn<`_Y-?^878q>CY=s?M5Le9 zdVYqjdC$p<5wIC3G6$TS1eq3Mn~1M;y}P^30)gu<*I#v%*d0IcpY*|(q3F@Bc5WBZ z3t0Fd_@K&Yg^vEQK%_rfX=5vif!$%hn-X z5{+a9U}Xq?W|Zh(91N#F;}PTPq%xAx;2aH2`z5KyKD!VV=kvAQ0c$GNF3jac!1(sz zWc9Eg7O{?VZ<+<$kKPbuF5QSE`j8(wG>vjD@S34D$7=ZVglOr)j(Acb3%}K+zwyu1 zetblaB`^QIUV`w0!AsV&?0anhy{>-%9^hd-&}AcC6T;(v>(?%+Qk(A463mf`c)I-D z-)5uPO*LxAC}TV;tnK|Vi*j`=(hUoZad_w^OkmU{hxbJvtJTEcDb{5IlfZ$`A4A~_ z6kUIQl2O=Huvj{+J_lsy^QzZpv$X4--zWD6yKbQ7r@&h^}gNz01B?^t_(a=x2n-pf-o!Cxf(+VdZfiN&T#Q!koBlq4rsMTHR~Z&9@!AxNpML 
z@`utG6m#;>GHz~-Jg=`eS-=-$A!3&YJtZy@J9ABZ?{H~`eYs54fdcV*Xf54haLVx( zVUITF$!x^zjg(tMSu|Z;>G6~GLVFkpCP^FpFC1sm_&x>%->FPl4|p(P8zLeiLvA&Q zs!tp;HZelO@o`DcboZ|w03cXlOcuN2j^)7|4EY`*IoD52VMkI$126fbQP;IvJ-_FM zP6!EjC6C>5xrjm8kUua7q09t0_7rlnlRL7xnMijM&c*LQ(U&#*E+`0OIne@Gy5F0> zwC%6bV{&>zJ5~E|PQyY)&vB9o=$f$1))QFu1rt}SrUOr)6afIdrQMvkRw3S`bKfFQ zKnrb2>MjONA>kp~13g&_ehi}~(*JRss&V)wQ!T|I|5^$y6O`y~#`S#fcj5$BWZ2c^hbhaZ`3p`1hP*P_7^ zPKzxLgg$ps&%g^^V!U1479>$*u}Z>V+@Id>TlRLtxqb1>cPyr)5|m`k3r5uf9Bsr=iS4qngNp?| z8Z~|H%)S7&#pyw?a3TTS5)O{pY9}_ELqo*-!YTLwIVsIRA=@1XGzi@4@3zpYD$;N# zqKen}WlE(5@5y8$ejGm@wS}ha@jt7-~+wBAMC-K#N2XE#og(Zl%>-n^0Z$IoG*&BeLv*J&mvB1zI# z!SGTsZ9ZDB!Ra!<-Iz$G)nnmM$~aL{UbBM~j>Y(fuFgiABJ-C;q9O4cdBW_NR{-pT zT96UVV#MG_atg8Y@ti=Qto#!|In$H7(h#5Ba`M{?Z*P$=T-qZ9FQYw`vzL&_j^81d zm^@Oq;4_lO$5-1O`|4b11yfX_N(Trd;SiGL^PKAdM|H=>?cU*I)28n$d%fDt4@q@) z1dKUHUlXhe6x6-4nr=k>xYB`O2c_X;yC8!wpqcBD#FcxU-6}=aPz1KP9p$ReMC%-E zM|t1;$=}Gzhdp%U11H)xvyHgaq4Qx&I9!V*8pV;{HeLz)U6N;_x8V{rqG^_BO{e~^ ziLhuI>J1y~=}PB~7V>&w#q?q;>BM3$(DeB5cOc1x`QSF&c?#X0?S&f=;qxdUr}#IJ z0Sg@c$qim5USQmx$*oqxfroHYHXsN-Qa#Sr?QoN{TL9#Re>&7wuIIKcpKjp6v?hqc z@24JPT;cGS=EdPD+&o#Vx6ARjp2dAP%{+EKe>b@gueKTisd+7%iDT zXAhp#8VMr~&*fwx+K^b}V~tt*cvhC#o3?N6x@S{^w4yCG^@YPiiZ`wu`0OcMzYr!0 zpz$#B#0U?xExC7+ctC~#R^PzKxSMHc4RVq@R5{t?M4JG7A32J-P~|*eQ2TL7Q6S{a z8wMD5vGF7t!t>{&Sm7cc8M(YJu93bQ(B)YUtc@xtC-8`Iz68ldXC%QaTcQtZ@u4+9 zd;&I~B_?I0&`s!m@2qov5b09t^Q9;+FfTN^r^f=x6;{tS_3H+wY7VLzQJ|zb?&D&g z8z5T1Kk}nBSm9M`pkpEdC3`Rf*1-~b33-2#+Q(WtkV96&aV2~2uBlw&?jmaV+QL+g zA(^y~JDW*pgfOf_z#55JKd38VDUPK$%ApX5>@Ic5nh7b;4I$QI@L}3Hs?E(H>2v`P zAnfyFq3LXip5}4#Chv!uAaRRdnV5HeEB-Z!gN0}! 
zMhA6@a47*W&#s;&B)}c?R7>5lk2%neq_@$VST;Fr%YR{Y^yY`Z*qMgM{~04*x2Gw6 z(-nMpS>943aB5gDFD#77fHXda+Y9cI2*KQqhM( z@3X!9pvlHOF0Uvrd9>Jcs*sFDuUA&2rJ}H$XSdRxxLO=|huvit%m&ucd-0O?ZHSis z69&TD&QqE~$df$7?d}_a>idGA62Ad2bITA}k3#M&{=KGWl9b`OJew>fzo) zt;rFIfH4-^ZaH{;V7-vv^FA_YZ7zy3dY6klwr5coX=4Vc;+8T!)dS=ij+VJ#OG``3 zpDKzCC6%WZ)ZPI*`KCXsBg6W{Lav;bM`AA{z46E$w+AAwbxJG}An>}7M8-s-h;Ii1 zZv<62EO^j5=`7=NqIivPuxb<-a2bZ<&io2 z`F^pT#uC}z9Z%JZ~=ih>oljzkn%KSbqHc~&MHtWCI zozi6n{UPixOfT&}g8f07&U@bbhuQw-?YYd`FAbqIaW%oUN?vnhFd*;H&z4u#ky^1e@Z~oQ+!ycsyBbO0s@!l zCo6P$Oog-oi1~dGS;H~vM5Roc7O4kINTdB$t68}{WzVQtJ^7|3Jka$F^@*3lzdOJnWnG^rs2eVT~?W%$F-1zmkymgtZ$bvV3bNwMIH zciO6zJ^qs}P!mBj23MNdZ~P6%`A1Sej`oN#Uc@IUIpxI zEBFHl3}Xg|)$I;W!PDzhWgl=&Mzwy3R@OOaBVhfW-s>BtWY^vSAcbfJlmq)D-&fJG zs9Hds;+mOyKq)YG#T3ij8yhUgzL9(sYYhAlpI&8gCHfUQRP!D`|7xO7C0M67 z6Q>I~E8hsxWAuC^eJnE^EzkrZbp6WE3kRna19qpMZ$=y&>$(4vMKOyL!S7%xW;|aZ z_sw_Ouv>B#t3P*FM-lJZ-fn&_CIACD_tGnwu#K~s;PN-gCRLCTN#^w9fbeM7stkVR z$sH8{7&qL1%I|nhKQuh7{!`duB)vb837M4lW6(n&vs#HW!yMD;;?ZI=^&C*R;aD=7 ze5*(+PjE1iEfo4C%i(a(McC(-z;>Y_Wab^VM;otxL7Bx`79R+JCp8K}Ll+F#Hik3E z2$>Q!FCR3lL0N3`3E0nPf$DpDtM|KkDLbo-(G;(Q9kV1)ZkF3Z!@|Q8JJ#oUCVFwr z4Juzlw1`n4;%VZHmBsZBy5t7V8(EEP96RHtRc5t+W~-(tMNkbmlI0$jiOUpTlWw*5 z0lDr3ESm5tM=RaZazr-sH43-BQze>`(Fb#1igJA4VZLp4M&13DWN}nlf1!*-yx3Ac zg!MX1r8By`R5f4gqH@>;0fAi9Uc-@~=(^On;aQ1xWn?nj%IaEvX$bo@SBxr);ybL#dzi{BhXc8GT-hA?A}MSfh6HS}`5Wx05_su5+C z_@z)4m#Q2owt|(_xS7`YS%loPc|peWNP0cRb0XC8HB51~!FRn1qsig|cRRc5OKE&m z)SL_qK-4=e@$d;LmY6bhpYLdOZPWS>r;-A}@P~+DWlz1a7N#Cy1Wy!Nns7%TISkqH zdvc|Y(cJxNS6V0}umX!>#fLsX@3=?obzU(1OQ5P|LWkGo0Zl5OwZCwy8oc-Y3GN*> zIRdWjv%a~SG(S!LoxQb@L$5zZ`SOHvXo8R7V&}b+<-EwxIOa)CX%$tDmOH6g$6w;R|YM7KHGhcbJvz;5+M<| z+}k`caNOOp6d8d2NTQX zV<9(%kj{UiK$OmI#+wr8g@*1G3o!pr!uJb(;Qaye6<+l>!no%ab*WxuJxUnWs^ajhntbs|ZEMVm`jvCHe^RmKHrn=ZK)wuN^ z@1EoT)HQub{{T!|749Vp+8otyz4!tFASty4Js9<)L3@xWOm^VzBPP1^gci%`E{tp_ zt%LX|^Py&hh((J(UqYN_DuSp|rd@c<6|vdn5Kc-j?S78n8s`Cip&o{1{0L!i_po1= zbae{w_UL7r^t^^kDD3ukeLMez#p`q5d5qFN1#rPk_Xr%O1wzXm=tP>T^MQ3Su%IvY 
zUwH^oQBNCPHo^A@(~!Gpks5(bW1N%qOtab5_L?p}XyO@Dyde^CWOYV#Rn?)+8W*yb zPqR-QjXHhR@fA#x``v_(ODuM3tqtZ6h5=$kjCz~JSCN7xoiLy6JcG?W_p^f+1;Z&Z zu~TIEye1;q<-0QeaT`^)7k-1;x~vJ7tzKtXKFo4!(@S^b=XUco9wnouN9iApT2Qlk z17=$f`AoQLT=&O|pYKf;2V}2e_GgNu=@bnLXQW3ek=2JAsuo7u_Kx$~?o>}iBBAKE zsLe)@3vuQY3JwRa{K0`ox0))Ub)wo=;}sHMbBsi3tdUT!cqdcsls~vW6`{?ka>t4p z$1aPfsi>)x2$U+wvA8mhrL#N2(6D;SLvk1#a~)p%RffF;SI98ZEuJj!;==WHW}PoW zc~!qxK-uAkEFmC*b|fS_5}PWI%4bUYhJazN?AG&k<$YX4f%%A&Shr6f5EHIK_QOK4 zDSLjo;oe@N18m(jmJ3F0nl8auYgtEJKTUNi)W1o3$L=o}6 z%zCOM{6N5wK5=(!tkLnJ4c5Y^d{G6!j4;h9ICN-K zzm=)|utM6#?3oSZ{OC43U~@lt+#!dv&X;U{(z3D=z;g|DHfFDGasbenq=5keYkIr? zdnT>*t0m-PJ8nJNW9o&)^cKi*HvBs&98(cOIPTD*D_9|V-3EpU#scJh$$}!nK zYqG_HZlK7p-kZm2vM3#mU8ja04oiKJ-SY<(v))E38oT@BpTt?9qO`X`4RZ74H4-kR z=%7o5eg`2E+E#qoOH+{hzyPOPx`5joE6li=j6xOZoLm{2iNr017$7OvuhUAc5x+II zjQGqZj&SCCAKT7iafZLJo*__z0zs2HT~r^!{`@_R1Sk|_r;OUD8Ukb7Jl+eW$qjK9 zJ=ABQ|GF(^oKTOn7dR|?y1gO5&d!aNO_zA>`+;{~DUrGR8+4Esjt_Q@7k_m>eJT(S z#}w6}dBWvxC>)wnxt;C%X0YzeX==T9m|vuJwul;i=S=&NSst6lKHS$3404=IDSVJA{TCKMZNE^j-z1Yd zF{`7VP0RN1Y~vK4T-04!=XO|+JMo?rUfIh^?Co&eeXE-V@w|mi<{+!{lUMBEx+?KVpB>owdN(FBj zUm%#R`dhN(0|yDEnD38MQ@ZL|Q;t2}B5EWavIa9j2VwUL3ze*%y3!f5^`XbY4y&(@ z`N7pv5rt)=>y-OrBLu)@AfZqswVaAVEE`Gwe5M^2K7Q?oB?6>IfpI=@vuiJdW@ja4 z`~IZgv@w5_o~2HeXoJtA5?qP*q_k20sDF|V_bSy=N(fGwhjIz%!d;s7K~Z#;zo{AS zI2^vzmql966814lX*BDdEKx(cMWb%@y?IN_kGLN!BKOmlv?>nzz8J4}wADlEy^~=Y zQcg+zD%fx3A+JA!?y2xn+`|Q$l=sc~n|=&oV@_+}&CF<7%>CE!zCg~1fFB5}wmX|} zZ9siHSaB@d$W|im6VqZ=AE%@EpR#YsTv9HZaJ3b6KMk;(w=J%Re=sTLv zjPc2F@wgAC_{tglbgog%wLqXyG8FvoOGVtIv7i?&5vx(+F%1L9n~@lKr0A~!#U6Wl zOo!!9yHFkM$BtdA2dqwUIb__dmq5ruSMEd?nstW#^)Ow1(krE31$Ueuk_nIfDcXiu z6~;6NUW8f>nR;KY0B~JAB2Bn)gVEW*QDbAJ>D-s%TLN{I9L}9Tmlhh1^C{L=m7$%0aQ8Z=Y8x=-;sLa5{9^DB*0i87@<{m2pH1B3Q;jm=BPiTx^OqqQKjB>9 z%+UVt3x`7npnk(RUrI{_R1Gv#vkN`b!snW_?!LRCG*fg0iY#ql>hZ*SkYvUByOesX zeg9~?L&G34jUN|xa(SU|V(2)m(_)3%tbeWnYAj*jW5fO8es|R7%_dAVZHY+1A1ePK zAqyY&bjPKnu4Fz_GU9lLL&19WVg_Ox+T_8zF`UddVY0p)EX@>GaIH73n|gt%`suIV7jO 
zZEv2py1lM>$7OsQ_vTG*YxFe%gS=!l!$#)7q9TEvl$`7sB^}HL~)rV|fyl8PX?!dNy=>7!`-j-nNc4MMJvi zCpQ+h*&f3-_h)Vgh8x3K+Hf!+5g5f%;?6CO2EZP%T&nKzJOVGWWPf_-{WUY&YJ!wn z6pm5CBeu6r*Xl*dYAyO0BH(^tar)h*v2YMg5|04t2!r6@{=IzKGK+YOF8S zS#{6Qy`!o#zx;Hkd8u_ji7JwMqOM*{QuGSauZu({ViGdb?IT-|YE|QE*YUWt3G@2-#cOHM8ZJI%JfZu2 z8(+8Bj>}gNxCBHDwh@1S={<2Pi!_A(TA3;p8o#x268{b!Y+GPESV;dT0m_FbR@eZ3 zdy@I;jX%Uy!;{tm&7CsO=FbV2^k$#a%qvFAJ4a4;Xf0df{+a+S==96{+&3HZH=_QI z2&EWD8lo%mHJ||kQplS@4>^SDk}U$`RLRKI8zwHIxOVTKFOLb|#c;Tu<6K{ZOO}(J z2#oV0EUutA65mJ0?S-n`s2Q~@3(i?IMOWpkXs6DG&mxfHu9v(`E1r_mBvf3diwj5t zIuO`bQTx5Khd2jLUPnJRpK2h-a`B52GBBy3(wY;VXse!_5E2qyAte9Nxo9W zFLQq@sj+q(pD6tTs5e!$=3yNdUTX+Hpj6T^aF7)55|EXXcDg@h9Xd1v3S`{Yt2TPe zBh0<|=(ucPfme&)I%(S7Iq*X|o!=L1=!qdYmK$8{q>q9|$W}S%{e=yUElXJXShUCC zU_-asVaN8vspe_^MN2}x{MJfDk-0@2S5jwxGa1=M%X2O+x)GZkpztDC=O zu3Z$^Q|mr{mA|3iPTFgkEPBIP(wm^rZc}=4`qr?+&3rMUK-4*IVx)a+Ox(S( zCi_n8Uj<@|bwRRKREVd|VW7#xUv01FrU3vVivPhbF-1sNzV@FW?u8Tma zgbYL#Q>H=p?n@Pgtj(J6<`^4d!(YcAUS>|KY z)0gc0&qPMb!~HQA>4bQy{VLOBy;cyghiznS_vse6$@QVvKPHq#?tKvxxg$HX@pwig zvI{`rI*UqQA8%P2nxkU;9vAKwCq>hgAa*EooCv&xSA@w1-%3|p64M5~zC{gK&FipI zpuvLzRTFYrWq+)Zl2A;uG zc^5OeU9z(6{2hR>+}Po&7(99&ATg)mjT7KOP(=a~ayxLoTk*M2FrJ{K@a`Wbr61N) zv4Y(AvwHUidmio^A9jO`DpLE_f=;v$!7rExf2W!$@F4|Ei`7UNH{S4Gy}hVE6e4VB zbD?Q~ZlcnajT72ki#2Z~ZYN%b2+lL_JGQ7mR4WLw(%j(rl=_lLQ&L?INZEN@ekO_Y zk3Y>}(m&^rmt+_ON_*yYymYnacdT1#yHMZa^7ha+_5CmK;ud6E^h#yX2xme8cdAy+ zKzI^_2+m=Ju9IyW;n0^8KHPcv6}B81m9ps9H;qCc0$OY!;mV>n;hqmpBU)v7u-Lkh zn5IHRaGzp7(%vg^h$`JR&XTJYPz=E0Cm747ZGHE>stq>qFuUi};O6HA3LNA^d1iSDV+;GTx;f13E^tmuIc_4d=!A1%1{p=!d5K`MMOQ z83q5&Y@M-o@tmOJ+o#EkO)Rv zuie6@Vz+Eh9nUMkO)=`@a_E;mlOashPKBYQ7EJfeU!l^o0?Szqh z&cA4l*zKuw@ns&S`46cre3FWb9+%^CZp?5s#R>lxmwV z3i$g(LR>VJzxX~H4sq?r!|g>UKxg9}Bwr(X_qi*{9LfejYPI#1UV`fd?T_e6B%4!5A!93w#0m2}M^zuUG&PT&`hz`xy>2 z%peXKO?N5xtVF>mQ65Ah(*!0|jY+NrR~#;Y|A$`#-)06hp#zq4DUz%Eg*XH! 
zTCmoP2?ay1=D?X09x+uO=&5*6w`Bx}cyc6(i7(}8S_mVzUi!m%N!agFo{_xLl!Qe01xP?#BsWU3vm74HwQ6kfhmL_ zGgrJ*)WO{a{>d1CH$6{Dg_yppDL(Cp&CsGC=ct-lw1K=m%VQW zRcgTEkk&cYlEEY5q~KDqubTYd`kIkDAOIG6JQ_<7rPUU*_a!00EDwbog7HP4P#H9dY72#CcuK;18xx| zumZ4yCN2;PmFjbV^F4lEn5HB?u^H$1PSJ*LM2Yuo#}pnSSX0Y(f4!oB3uj z7)UiJ8Dj!^q2dNTBmS{b%JLRPDb4crYovq9oAhttq_t3U7FRcgVhv2z5gL6Xh_}If zgUV;=5|y#OjWJ#^ZogGjNQtf@1;7+OBRmZg?w|i;+XNUW%NJUoIjq}tX3g_Zn@g3sn8z_(4au{S$_L&ra@|U%M^XBmPwje{oEpBP;(w8o*Tii>@rW0V$F#B|en-S1rV|!B2uIz8UxTd^|D8 zoeIcho@V6bzh+kz0qBS{+wskRlY#%%DX9{mlVgmsUjLfiq5z;HOx31W|6)=AW}=>y z-I6e`?O(GyDGqc5F8cq)<-gD3{|n;cM05hnmF2qAI63>%8s_;H_E`qssRPI2AL^#@ z6p7%H(Y^Q^nb**#%E%h^D}yD3oOx{!ntJxCSB#GzX->6c&PY8;5q~6H0DKJizmMo2 znjub)Nb5*_d;MroY-6SIwfa-VZM@$!{=yN32RHl@Ai)+6PS0o>Ati* zg8^Nj1}O8_<>vnlbLvmbxp?f71I+mYlo1HR|1qHd(X<>eilb{!O!TI-r>+PhlU>j1_2RKNUFrH>Q?_320W6U$_o5 z!vmVd5_~fMn`Y&#K(jp+IG7wT2$?XQaf%tj~I(-H^W^0#i^w?{?{NEq`N1fKE;@p4KY|h(J(eNGuWmZh^V}lc#@2 zNbyNu=>Jdp0%*%Wgk17z=kS?x^vRU|mqh;af9C(m)BiUWZ%?FWHPh>840vL%;`G#k z#i&&$+1R%!iCI_D0Frz#$zgrC%5&22#^tv085I$QXXY|BFt+un?2Kh zS(loUG73BpT)e!zO!6lu=Y^s7Y47bOn_qFual5{X$Dq1iKKVI5JffZ|V8oCjl#~Ce z2m$P9lT7ThO-_YIf8vtq6XkdLI&Owva94V7GFgBEfx4(cU!r-RyoDw(UqxutL=u0I zBf!CH1y~`2^YZd;d)a(nq)DuLWG*DHdO*E0lh}-pQd0-a7Bi6^VJna2X?E>%K3jXG zm9AdIWkzj!;nI@UGZjbFhWTzOLZKq`HNu=`V5^upsKZX;bJ%C!liU4kNM08|c*1Qn?dUWM(P;X;V}5 zA4d&Me1Dxx|ISEp9s~q^(H*C(yL8`^KuSw@cb#45`K7+g7w*976mmg@)d76Ogq+x| z8thlceZ=S9+?QX%?o1J`i45!m(1}HX{(fEZSFv|}0^OP35qsY)tV~a2UUIt)m+Ln) zt9Cui*2wazn_L|&9Z^tTcOoxlqPQAb_mvu5FapO$N_KvC3ogG{8g^zn9_q9z42rRR zcBGmQYJRzsKFMhFlSN6)O8YyDqKK{mEL@9Qb4y3Zx5vOb8dO2Np0{3<+qrBn7@<%;oP;h}h~3RWo9DpP z)YPo>jqA}u2J|Le1gxzmk}fVTyBqN;({mPC!#)3SqEE*n759tee`8~_=L5Q^V}x_F z?l%cRJjp+ClZI^mJ8`P9z z*Vm2u(>n^ljiS;08c_?CNi(J}(v!y^28?;PD$@KD*ZONy{3MBK zX=V9f<|ld@uF?pdGXlHL;z?iLG_-EAu#z{P{v=zPscc+BYe*ln5IrMFqFVb3#7?7t zy$`B*#f|dUxwBLXumXeYyQ*0D%uM`kJ#7?_=in^g=UYTJ=H@~x-)G^|w>l>j`?&95#j zYsvu&XQ?37^G}cQPg5B%jaQ@vfus@m!S;>a>>d^lJ^^)y8a`A9qv+S$INagPQh@tN 
ziLcv_FIlcf+}iW-RfJC+e&OnF+9XK`Q@90m2`!&u(Zs)L`;BQg+IF?PFB!5eye&|! z3GoE>%sEdM&?nrI>aVLXN(!1n!)U)5w%fC{^GaLLn7u?zSyps^e1Lrn&t7?axMEdJ zv4)%gE3Q+R;X1$m*~E+Be!t_P6sQPF^J9~Jph3MX}v9X@jyyXyI?bH{eE_%EYN z&iCJSpIn`-t?hiHL*A1*Hei3UHM1cjZ)j*}X~VGT*$_HzRWD3;{G9oCzisX8JZbAW zgb^s~Bl&?3G|UFDcF$7~`Zugep**ps6#ckm_aM6hJ92Y*xY^!cbb=$YvKZK$3*C1fc=7P!>-}6XR;0S6P?)2z@Ql? zt8hwY4kYIX@Z+1q*=ni6OwW>Nu?KfdOw8SfyJNSRn&Y{G{%4}IHo+_RTYSwI)B1*C zoloIm&>JAmSDeab`5SK0pUfk>q}Tci_|I|UI#z>XY%tt0ouKhYJ_?mk+bo-?-2Lx&vbb@I4snLjEts{ckH0WJV~r-FIHlDwZo7?q3v1 zJ-L^8XZXJQ=AjoD_lN5r)^AD+(2a(6gWs{AY$YUYoOYK{vG1`VcF&EScPE$gn9GGz zJF=C;M8kpId02WcsIwZ&MfC4V94M}d%FcCpvP3=IlC#p%v086+8}r5Bm>GzUj&IQc z+u1P=VD~qUpPB!@bVxwzEwICwDy|s>^z4j^U>6U5H^Dzadp#XsB4N=gUw@CXwxpoXvVY$cBvjkSu#pZL4CEz`1&5!@7wts`oZBfF-d|iZ^G?4AsbG_G z`WjgMc!IEzyaECxQgmyUNWi7DrP`I5SMQ!50*3C@9pI4^B_$;tQ8uiEH^{v=C|7p) zo@R<#>L*y6ns#xh)LD&EU0g8>uVGd^Yg}}RZg5=s;MunKoPvsKCWe2xcz>a>w5X_$ z!|v>6TkIE}TkA|}`xzsEjXtU6z+m8HHNeYefUnyt5qMtk`bX`Sm*GBk?<0VP|7QuY z1)(*Zjefda__Uk#>ndw{Zmt9Yxe-WbspBp}?8bx?ZMNFy!qM;BldN&lPnXgnKX>v3 z51KY0JUr`nvw^n&=FDGSHr!uo)uz5#ML&Ok=qpR|@Utm5fOtn7?0tKPwRijueX>$G z^uJXI=o=evFHoVy3ZF!_15R@nz#s9rHw(?bPQ`w0z#*5rDghtd_hwiLUNsqa@+r2% z=k3m0zP33QrkR>-En+RZc-!~jFMUodGy8VOvMI(xod@Yv7BO!QtIkS+TT z1fyUV1Y`GpVUY!Lnv!L%k-#FlUy7LLy=1D;p< z?c1GiGK}9JGuPs!^5wDYsdB!vuhl& z)@gAo$$oXW%0Cdu3Ac3)>Jm6eApg=mBq`9(f74+Eo>jpv(zzB*a732Cl znf<^CEAQw?fqMZA}o!k|1_E$a`At;?b$u7FsWe~PNVJtRhQ-{L``&wdUk@Kk!U&FA0Y~&uG zYjw)yyR3ir6nM5lfItGTQLbE~ylp7_6o!^+;4(n#wDjvF_0trBM!k3L!sLWh6+{n; zu-~h)`?Vu0HX*cfrOf|i_3-ZCaB6y*x2U6|3k@5Uy{d{ay6^!uPQ|A??~8L405fsI zs@ibb$z~g!vdRJ`*CYTcnXGNNYpFOp9JMWS=(Xh$GKwI-5r_Z|?B5vy!j^?fWa|Y5 zajcvIyjz$0c1N}zu!H<#jwri)(H(pb8@l-fM1S3!p0AUBn#HYK=M;IV@!T7r4hZ&y!|s+ggoA=j`QvRvMAun6)(qOhy*cDa1Zh6+L`J%vN1*20 z%RVuW?T5nMt6zi;ryc}cN;cNvqXSC^D*@m}fy+C<;QkE|!88%;IKzUCYfo>YU)Yy& zbai_eP?-uqe(zt`D~E8rsy3t_zXHai4L8wNACQvGy4z<3KJZ;dY58?`{k+522lZv< zb{EouQ%4Hl2MAY{$Lxemz}zY+mB;f zb_+8W#alsWBvi9FA#gZc_l5o8y#?#p6zZC-pdSvucDuWf0YV25UiT-+VBiJej>d67 
z%z;3fFN2og)z0NctY-$5{)$$)UO}wyCOsf7G$DTk2Q&b`sV4B{k1WN%d6MD};`PWO z=u|H|M#!G-<|kUtTcPa1E!#Qf%tFq9SFpM|?Ja!kZSWljVh_Q6=RGNp^qQP=uI~`6 zvk^6(L(JSLL)4y?b|Ip}^4=`vuX-FDKF7tK6prtQhn^Yi3k322UbdSl8}A`CDg%75 zMm>oE0yL;2K*OGL4Tk&f$arPPgB0B!#tRD|n{j zc`1NyzqDN$4Rv(D6|0=P_(U4uZl0c_!%RDSd%(rAVdxJL(5xCzhKJEq+rLXFYCHqf zp%d6NTp*A^JpsI`WywW^Pq70IYT8UtT%3R+h=<+jm-j+TYNCAIG^s<)7XCk6;qD9S zt10@9s!mRo{(!G<_QRs`J?czvLpx}{cI01JZNENnMFDOdsHors;zR@cuEvTF5$^ZE zFeD^|)4EC*7r_nOzX)st?o-0a|1<;eC4P!C%YZ|6>K#vk*3fjPDF0!o&N=cCmTP}r zm4zrtj$v^^$^-CCjdckQQ8x*xAEgG&YixPfB!GmE9uP?`*K|TK(tnrKP%VVoyN~e+ zG%$wP$F-N2k8ntz4QjVfgdw9-@i_;aj_~ij5jq&#+|1V>-1$*EI&-z1b!f?~MD4lk zRdl@w-(Mh(k$KKyOvw~*cbfGTh|OqR%mjdwCzgS#onaKpu|}jD;9R$jy}c45HeykW zYCqHciaE2`o^KC1$mgKB&+khGx1UZ%J764J`Q~cm1K?Hx-{9aXAD!$Ma~7#Kz}q7$ z9yK&}RtDVKH8O{QNJ!+~0e&r~oEum2UXU{afwqeN5zG8H(YGc?1Qtd@Eo-saZ(a`R zFRxot&L$|JQzQ5jHUtqVYtltS$bvW{utdqlFYf7Z0SIr&T;H(@nBzD4Y6A$m zj0-SCH%0upA7K&H{yM3K-X%hFDNnuUCo{Fq=bfPd%*6#@?o9EMw%_FFB>;Bnd`W;H ziP{D)!t4)G7ZzvX>^nOCX1xX0w{@l3}#3{G>Kl?$mzI z{d_VF-CaSk4vAfM#^SVsNNgxBR9^A@n)W><4kK zAJFUv&}_ue19 z_TNPm*a1mI7v^);5Fn2j5A;k(oLKQ2Cls2QfI9^v{3C>3|AA5D9)l zqIgb3&xj1A+m*HOZw4of=t5o5I~YiEA_}C^-1DY?Gq_2BcR`-Z@8SVH|G9eU-yOPj1ybT% zDr1Xi{~Ztc*Ke!bK{ObU1oA0oZ@El{(P(fY_16pPkGipMDiTSh&J)xPiwH@@=GseX zj<(RHIran$2XrWt{^<5U)fYXj2?U{KT--t%#g3e?*tw5HuAoq#zDBm^X-TB0db7<2 z`>!0mGKpBma^V)P%tm;KM8PcacJC`bU%;NkSwguf*Vw6L@2pGyrw=hHsFDnCZ@ z6B0f2b+|S(vVFDe3qRMJZ~H*e_>JO^!PO$5vChBg&mVtIJu_ynH8N`2lbq+dsj^R} zOJ*V+e(2}{{XD6t(qgbJdhE-XJJ3>}Oo_;Ej35dIeh?Q1{L$-QeuzhcqWn8Yc$zHs z`xr-4jeLDgS1n2FIA)A8Z}B;H_3+m-#~{F@vX22(rT7$n_V)-0p0#z*@o<&n&yf|o zy(8+ag*%XXdeFX+@vqYY=13=pF z1;AqOSF-I}7}sJ)1sQsLx9~oA8%-3;q-PL$fLN1yAP!#K#_joU{`rqTpGJvuvwUxi zy;Eu|Z-YYfcO*jIyv(W)y!d6q?L&HWgV=&opn8St@$e$*zl8h0S7y~3Ik8#Q?@s%x z&BBiN6*Sw5`eIw~*DyaDzJt(wiSIuOAF%+7t0}%o@R$AgFALYliHF*L;CIICm;z-_ z7+bj8q^|$)Yt!r#m-ZhcWI>b{voneQ#&FVbQ6q1pn=GAW?A6!pJ$>%{O+WcHcdN1J zVl#rt7z0lgM*rP2bMRS+Tz_^S_|!MKHj{WhDvI0=D|CP$CTQ(}2#{5B-Tn^E|6>KS 
zaeyW~ahDi>GtvK&-~S+n|1|7LKvErsMnL_$Oz3|-XATg=w{2Z#4}P=o|JZ^c9>8A= zim%iEZkGS}AEXr!!&>dMyWbS+{MWZ&2zN#<=li=2{;%g)0y1O`$C?!TyBz*g7zTl& z1QwK6vHx0xzm4fXVV3{;mN+@EX%;lc`2Y9kyahl|_8gniuix36fBcUH5i`YM5dDo8 z{pWLP5zxaYY?u1~zKQ=Sc=wOz{I9V5PBMf3{~;_I+S#zU6R&~O4?fRZ^$gUc)L^%Y zu07ufSXjEuCJY!0_yMdbfebgofc7daSohHOu+A{PX9iXtF))h+?RJS{#6M$D{W46|9-CnyMpx@I50S z)6lJBeJpD{mSj<&{>Nwm>Y?$&Z1=mHF_$kpslae^`{Q%C!u$6?eemjE>x0J>08K5| z*oVQBJ7|(Da#VpQiO(UlSG^vNj)|k!dUrHNj&AYI zlTyH8!H8uz5$ZpMC^>dKi*W}|T=jk+E&p>BbPx&Rz`}S}M9~5&+M20$0nR^b0ZKp| z2)q1`P-r=5l$Oi7Gk4PB&qm2w7#Lk|scvF@8jyIwQt$+j!xmtT^(r=#-GP7c@PGO1 z=)M8<3&w`c&2u>jzeJ2!E;ED?^0@t0evXQEDl-!E7mF4&2$+9x$(5qSaYm8R|J=ur&R{;41m zXJfO#WMZCw^0E^N>inqcO8@6263=!4dK=t-S_=kRX_{=(Att>9jOe?}iVi>U&@%6F zy+41$S{3*L35V~?mS@IznR;Ucz=zq0`cQ=_y`zV~N9=;nW&Y`K5Yy{LoV5A)mRC0# z6X$JNwi`MS`%(sGwWVOF8-a-fcHLnAbN}MThQJzqOU}r61Q5awbwvIb(g1irv#(Ce zdMd_TPE(72O?ed~@@g(ZU&^9@R+?nZ2&WI@2S(H~&OZ}_i9=CpSV!{bMb$$@1EsO4 zGaU)ZLW@Z15KgriA^HUMB*GDYub|342G)b3((>7iDUjM{V_AX%sWZf^niI5VOn~KE ze2V?&{_)bZfv09DbXdW;l&fc!GQqZD**?z3T)ZjUy;Jn+ zbA3u|KcZt})HG4|(A((@KjXf5#0p7sl(p%TtO3iog`H5&lfA$o$ zk@QOHwlq_H z1AdSuKP*T0uLv{ z*;zqXZy)@8L`J3*TeP2*6eV5K_*~(0;qn+2d3@^V(BW=8!7G8-4m~-pOv>@q&mQ%C zLiJT<1(d?)x02`$_JI{J9LPJ)n5pwR#TS$d6>bedAfm*)r~+9{Id|~Rk7QE09tmi( zy=uIK^1zS$U~{#ayJY@cA$1M@CbG+Up}XQ=JD^dmLLtN}vJVUB8 zKUUNBoY)Z0`w-^~7FpIkTqv+}_ka*gI(x^4Pqt2{?@R{2vEa0Rd*6(1JhCPcgzNN3 zosK_enfQ^KQ|yVf%}Nz+rXscvJ?G6^Gu_e&Mmrg#iF zs@vSkALd^yyawR;*j}~-4e6>Ta^vgSQHHT6$33+abB9^>hV%UEY$VAfN+M#lHy=I> zKH>Qq{Ssd#%wDxyXfmod*Y@zdZto7AmfI|B4;25<8G zWmCmD0O7YS5mNcu5D>)75MVGUOxJUvkeE5;@gYgaVf!V=d*|B7IauGfjK7g{ptL z@W98hm*=|Pf5j>J4KCYT`S1DCD%vWxCcEj!$iyx|9=9)aS`X$b8Wz*EvStbzCbg`s z+7)9DSVf+xkVRUnuG~bR@m1*$6W3d9G;*lda&TPF~j6u%E1FMC2<) zpATXbK8q9(ih$XLmHmZjof@|4KBm!ynbcd{W$S!RyQyxP7JILey&-q$0+D~cY$3td%WBR|+!Az#3Nk+hdH!p()YN`|?7M5>^ zs;!??#v_aRa}D3WoGDRohs*oF9m^h!RqZx<(~-TVD2M)MCh^C4Jc^Hle6=*J#C`c^ z^lDXY5b+JecolhvBDsMA^^nn84t&+Ox?J&`w$4V%@j5jSJc`Tj_NP+|#Oe&TPjGD; 
zIqPwgpP^zBquBfk4FloxOVS-Nh-FRJ^jD?4w%2A)Ima$%7+Qho%yAW?F_X+9 zja!Ye6E%7bD%1NPs{R9nKc)|0VQ3^IudO2L9j(j5Pg@sV&4y;(WPZA@cNLXXnpL~Z z1!H4-R4&ot#xyndB(j$HNvDRBl)-YWa`b}-PoL=3PfaF(#=Z12i-|}s=WA_-!$3X0 zHZZ|C{I4gIKGTgK&EQsZr;EeEmlZJzpyr7TxS`Z;h4kQK16D@0iJ;1s)}SL*Kq^?M z7B(_pT;Hk>d$LV>6k;8N+^HcARWupcXFHrml4jgoBl*2MA`YT$B}6enMspki7 zq+e_H3w;&|OQEXRceLHc!`k?Rq!R(k;FC^z?|Vm!lw}1-!zE*@7*B8) z^(W$sCz`n(W<9gu=8eU1qor|=2UBcb_2em(Otx`5KEYM63kM2wnD)uu48PRbGpnA+ zXdk-H4}0Hm=S7ZPE|q&Su@QB+oUh@p=o9K++}ka#U-VXiA=Op zD#}?LjMARZRT;W*U`q{lq9qrl>X!!bFO_yRDzT!q*)XK7jYoCsAhpIjp7TbFCB9!W zGlEU4i%pNy1kG**bFNDU?PKh=+hW19{Of2d4g!_q$GI-;MIh$0A)|V!s|uqxa=XO38pY|$GE=EnkCng1jg%#A zF?MNx&jlOIOO|HrE)EOTU%!6ooS?#kU(dors`$L1Gb50l#&on(T(6Gd;I@q3 zy3<(t!Pgd$8i%5rYE0r90ckU}mB7jaD%R<<3;Ij)N5`>$;`8#DuMZ2o z$}9@S-63U%VAN$ladtWyZO8ezy-f86Fn+a-b*CFAry$}}0pdJN+(r&f!I`4j<+c;} z57sstCntAVx~Lx0g3iG=`KEqL;0hh4Ii--0`;=03x->q|b*j_{E_Wx8Oghi{%#K4% zt0P-B67^>6%$j6%Tid^j;T9|f&^O?We>qoBjsU^K&!_NXdp3JE2Ak5+)N_JH5=y#X zW03NAsm5%F*E=5S;gWhLOE6vFq}Qy@_))IpIGEDyFHRvlmnt?cNs}`(!C_q1N)h}=NR$kGkWEL&S)9c< z2MYO+yOQVjp-wnHTK-#g1qqj>OFPxlBkPqk5K{YfbEc>QrJ5{J>}GvFeQCqX@!N;< zhoY2ahdU1Hbi|%U8Q0eX@E@{U%>{4S*~yYDpF~M-k@`SL#3W1_s#2K z5jo&jGPB28LoaSPop?oRjh%-8+#sC0D5cdM>!{5oSP~r8s-wl%}jR_XGo>!KQ{NA|Q?* z3Aqgh<3W{Ct&O`|DYObdmh;UW4WKzhBt*HGF%7uZbQ5{~%WYA(2k$Yl3!S(hQ4mFg z>fITL5BCJk3xgb%fs-wVl%qKB-l zyoKv6xe47tBpx^{`1@nC`6Dyc(DleNe-c~9gv_5>?gRR@YGDKDI(nt(xq2UI z*1IZg9nJM^vnoX>d+YaN>A|x+Qnh8K8%jvdu@g!mC1b`kxnT{=y`7%|Y>frmtU8xRHl94e z4sUcwMM#j7T@I9;^ZfDo^fZ0Fdp%oCw!U4af6F`ut}X06+D8W9k^+(#CPdQup9DjE z4^v550k~u|Y_8PLEfv0AzB^Aj?dU3L;A9)JF%ZG%q-`T&uaC9q+!yo&cU~nqK@q~O zJGOfFVmzxo^W&FC0($a*&9r&VT2EV%Whw#Q$op zi|DJfQM{Z%$qq%mj} zq;A9+|51-Yq>@h*Ei*-&F;*k0&s|u@qP(Bd!@AFlPfq@q>Jh}G76&9}OE<1c%=f>&&tnt&VFSRQg=IHDZ-FHdfMi*yO- zXA{j-|9qPK+6`QQRN%@A3L05`Sr*F|_*UryWODTto951es<2}YRI4;qx~}Q*I@(!m z$^C-oQA?xNwA*%vy)PQuKRPH0>FYRWM>_3YHulwjzl=kBo$`S;|8oY)Wct-UVuho6$g+L)W> z9IUj;T)WXUaz43*#KDqJwRJf!R=1a|ZhNAM$aN|PX!6cO zK1~~YlnVjq$ytptD}Rp*BeUb(prq3^z{ 
z!9w_`);ShdU;dgJ4xZ~q2KeFQ)NcDg?4;ke3vY;tM-*zkkh`coK4 z+UROAzB@Qw*4ehObN$V``Sq}g6RG3Slb(0;tqV!EDouW>!MbRtR+%LonXm8*CIHS< z|13LN_P$1;%9&F4*Y-2hTa#bA^I3aBO4fMs*eW4qnN`D+UF~Vut#4)vvHW|V-_tze zZL*ON0^I(HwfVIN$Sl!uG?;$cQ4q3Yv|d`v50lj{GL7~=4`t5Eg)bL`q7hH!*(j%H zni%_-XkSG@ItAA>$zM)x9M*KnA zN7DZp<%~*C6!7L0Z4n0_#SrJMP|)OU%{kj*hp7Vn9V^8ETnQH^MLk=5z&r!%sthZv zjEQ*?OmcegY+2xuBj5_N%k_6J>8|3FGLVcz_Bh#R+cx6TLFh&+DO*Rk9}~j3rCsu{ z*df*Q1e%Vjso^Q+bz%*Qqk$dIN=c%Yp?ZL8jQ``h4w%H*^(YggL z*P9mb(vcxMFE^#q(iG z%A&i!us^*t4dU2HbiPJ=-Q!c&Hk(sNUGGIK+v?D>(rDF$H)0K`enx-CDeXZ03S}9s zQH}T%|Cu8^DS)MNR&FK?bYgmPBieH#eIl%HxObHCQ_tGC6#u;9yV~b{o3RZ`zy8tU zCp@`QT;EdCKWtXP=ghbo*ed<>CkCw)GrXxCM*M^-N7#z{hFDA}-!r3d=xkkRZOfq~6;H@`W!|}mMi^yL$BJoRpVj8 zL<5f86^hB54pFLD8%Udm&l8oT$`f=mhkG^U<x`)oAyGq*;_=g39iu;yA5TVD{yBT=T3HZMX+27{O;!hC-0n~La`ZB%a@^1yA#GM z_7QQRmInW<1qegouGD^C?V5C?#aitAsz5_OLNHfay0^iugL0;;zZfF$L8+a^XWNr! zywvYe-#}JBxZ1Duqg$^7s{VmV4s;^;m4qtcE`AIQL{^;+d_6X;E#xM}3no2I^WuZo$g0*^JJ%&)(@!syFC*>Nk3zx6LO% z1Aj+TvvYk3l}c}b;`Z>{xVd~yu4^=PkWGBC&Owuq2&Rifdl#yuHKBziaroFxlavb?rB`yqwc< z!7I$OZaF;cp?m$PE!Rm3`)8y496qc5yX(Z1wr=KhK((g(=6=6!rI4^}^>4qfY|vb4c#(VAeV7W=nl6 zZV%eRO5xDvOG(l;5!bp~x-u^YeY@Ptd)Iz;eqN%(VtxLi@6pK%d|d9E^Pf7E@b-c9R?EC-hsAA2FF|X7XQrv!l!5b4tv;8ev@ES%JT!aCcF>$jA1h z;4B+Azx!o%D7cyjZa?sHx-Ps3Mcr+CMW=v|uK`1XdLO>W^LCo>!N&7WT!!LzM*T!b z!d6gc%_)dCY&^Vf7iy(Uh(n{pOWS@d#(OXh%mVkZugG+A zg5%{0s#gWlsnCDY3<`kw)hflAYP;m-+M1~v;z|^JG4;Hj^{-yRIz3a?N?{Up6sw1} zL0D)ktPA`cBq#YS#pXO@uVr+k1=OZR3N{hD29jgiR+^N@=cTgaNA-c(<*Q_nn3rsiHt6kRYf^P(U(?EjuLxZ-*YrnCGRG z-`#zYWD%mz$RJCJck79TW5WdNXv_}jmB+;1B$tn)baVC~#)C-&wxaK@Z#y2bMG{D* z;`F~i3$zY<7Ev09u1S_w&eb0h$HW$m6LONXEEeVyg_>><5(+_ze?cJCJg;Q}>+h~k zct>pcQixo)04d(1Y+ja$>k~SmrW7~dR%iR2E^wl#Af(sx&v}(wUqzc^wl$ zOMi$JUB;htt5A}%x+|L?MX=b`wx%~oI&aEoQh!704T+sl_a7mY4NSpxK>yWAl-(w? 
zyM|OB(TB3M9@=?O>Zc2B11>%iiP)r&!nY>BKCQ$5KLEqRpBIN$z-uN(>; z_KrVK7>=G5vAJQpIE)6B8Fzicam{=2nGSma4JoooJxmPi3z)g)o)+Pa0sE0FVNQyT zvHukJu$7j27}AsX(~ROIevbyHUiM2>!}5{hK}>HasCtagIWX%Nx5RABzI+=R=YdY3@XBSVo!3!|C)jx|njPZcJ@qG*0$Kl3!&nPcyM zi3T1!g?2w{6(fbJpd6ZtHY1USyClPBdMXI>h!!EiK|am88r_)Q|a!JO;+OO4`1oV7ePDop;|Q^U*|gdkui88HSihaoISpZmwVJ)Kih0u z-nsI2B_JYC&EW}?)EEop{M?!VmSn)kCxZA|jx)DW4@TA|r)e(u+2O9-=ejPysj8BmKMd2GSo% z(2JWRd67l_g8DTH0nL6b(%t$nGl|jeR^01##^7O#lZCw#$@+Mo1;YGNVoVE-2+Bp4 zfS+m|JOHj5)==cMEPvBbssX0dU8U>l=^>|66q^=N*>usQ-LECNW zgBI2o34zJVZWcqbU;Bg16Udjp5G+E7oKo$XY85~X0RcLGZM3iGi118m-r{9C<2R~T zh!e11C7?}F7nGsab=kpKKc?|Z8vZ=w6LDos+`%}j>gpa~>Pqx#XF(rE>o@JnQGOF- zFf2h~eG`{^79JOe+S8UMEFYp&BhD$&-H)uI6Mglh+boI|cwvt!b(r08qj7bTc`{lG zvRp`+c~E6(|LZ(z%_pB=u6o(xHDP*Oxmsqi3GsMi)M2Klr6M`WAUfU{Ysxvl4KpF#Y!8w( ztr--x0dbwW3ic-kjQqLgff$cH$@L#Ve8X4-2O9a!)fAC{E0cC3o8DoVy%ktdK#{?c z3rsAc=ow;LvuMVa_U5^M?w?5r4#V<>2CrI#Nl+O%iyJ^0iipv4@m?!W8xs$3-F}KB zS&oR9!F$GwZgUTs{@BJ?zbCME0*9KJUc)8fVFEWbkvz?YAv22x7I=&0^Xr#99@EyhC|r#wqR`b=83J*@7jE@PD)v%@?1s2 zK5~^Mgyy$GiV|`ZdFkjysd{%_lDT`sb-#40>aK1*?wrGQK0gZfArH9JEb9RC8vBP1 z*2(VQ`*biT*#RZHA)V2cB5P`0cpdBd@EzAncp93uu4B~jz?1{cq`KPo+tnLI|uL8U*YV| zXOX!IbfDEuxrvM;&vyhe@&jO8GMnt7s(7>#)*tCSB>DW^eREj+SDwjW)4JgWDjr$Z?*NnfyC@l3k&Q&5*d;s)g5eK?Pw(HOe;5mvM6rGa>sct!C- zUHkC{Or*Fyv4!8;Conuf9gz}1t;MhLS%Wc2OU5Td#(vUD#{)nkby7!%c1VIm2l*DM zd`6}&E}sxJGPqwZWDgL^hNM7bs8<4-6bXhFq3^qANA8JGPPH5Oc8Y~7DJJ0X5(-}N`G{xyd^!hjzthQJ@%9o2@&Di8rX~hN%QB8 zeoY4w;T8E9j>R*C3%lKeCkyF>tNDFk9{9wZZ- zO1;L^F)~T$Sb(x>l=Bdk?w)||(IMJ~SJPWUw9#&s6VZS9_WySEqaw)fxL9%}-_=j7 zKE^1d5pmW<|6BTleL`&1J|rkjnDc7{1%Uu?paU#);rt7#tey(}N8l)dV$gLzOb zVWN>2yE&BabIBtGDNDH~V&*Di5%Q>cFz!V5n^(c}*;YtSP7|Fo^I)Q4umlWja(v+v z$lgTFRTq!8_A1qx?Fy@mt2-$Sut0%4t|GiWuz_vOrVqY&S&m&xs7NHx7LpXk?kxN&^BL0k6DOrY7g6sxTL(C)34y0~i7M+tc*bg}xG%3h9ybTmDQSeK&$>);fDdIH4_3oES z2Q$%P%u_ClD+y*~ZNAB-TJi7MSDl-73_PZxJ3&RG?PAhS(#+kJn{oF4=Lrq6>6-~1Wj*X@k@Ly7UMy% zid6Vs@IB^c{AE4*DJ*>pBzYNmI&!#ewgww@(p#__^(Gzzibe|?OfOO==asUHIh^t- 
zR|0xxcG;^)a{;=8bmL?VuJnGk4v$Qh`TTfDSq-qnlmpk7p!Y6^WbtT_XRjbMUUbi~ z1_&9&sUKk{E_ckJ)Q1UUe?(^Z7O}Rzh7>z-cT}l?glyx@y%O9f^c7nK=jO8&I+8M( zu#j1&e%GHt=RXDWL@JOFP5PzFQ#llRh9x5pFFGta@_PwLIe}%XNKZBEH33IeH9%lW zV4ojkm>S%~@Pu6ppX3BmrPBe%E@NtHCD6^bfY$Y?kj!QmiCetQm#R@V?er$>H<KRD7dBs=c}HRqt4+1bCHP3BC!pDVex5On4ti{X{O+A5UR#?nhb1%i4^2XWYZ(V( zTFG3|1TTs-j)H7F6!JBRa`g|8;aUKReHO-+t10C2C;rtFjAx;sbIuBrijYhv=@Z7q zks+J{6AydTbgZiryaXqM9>8@t6NKaMYj6A#XGmmDNGnTLD=^N)zpE&K_0}@Ed4d*2 zd?*WdQ2QC!7@$(#Enz@xLQC{sAq{m}Jttp{xTW~6xPb6;i`R7nlgz8!IO0q;&O(pk zQwehDRZw^&8i-3e1{A}x0XaR@d~(tO*P{tQro&a!k&e_$?ZV*YDJy)zBM}(1nd=*c zlKU{B2US(0Ch+OxUgPZ6_!(0m&j|;?Yg3xDd0Q1gH9O=wJwB{QM)K0}yve^H;v#cn zzo(KcXq%1{bU@NyhBgz>tmK$VJ0X#yCPDY}#7OEJD1o`=JMGu!!Gpepbogu@uNk6I z6W_yB*?hBW^J$ruUkh>DgPxEV*qcMMA37?9_n%~a$yzL=>ABf`PZ+@bUXx-+{aJrx zbciVUzRY`kmt+uw@>WiDPA$eQWSHLKk0d)pf%fcqBhjlT%~D?366 z>dP&?1Cvd<# zur>M~KPeCY-D5EJLAPUAAf%Kam_wxI#5}y3h%i45%z*Stkn2Q(V1OMd zP&$H!4qMKDNkR8#*yKCUI5Y;8d^#1B!HF;Z#)a(zoIN=zPr4XI9(@z?OvA5B%pjCf zQPGLILRMDxBA~sv26GY4T)u8?!VjRdw39FXOcR+j?KIu}b#AyHdGLFHDUXyF4r+H- zoaF^1GQyxdc+NWf4H!G~EspB)kXRxl^7@hLE-+FXMftHJ zin)AuxEg(=d1OsI2wtkw+xBpbaX}t1&l*UkH4=x!z|~lIG>KM;(Yb+7%e?hT+5{d3 zYoYBZabfhNSp_k)bb#@fj?QX^cAuM8D*#MKN{U_noj_-{VnhTR-YPU%mmJnV&80uA zae-vz$w6exQ|&KYqjUdtF@>Z}3z8}uI?rpaLF9DIy1buVstyE%PM!PKY&4TR>3kp9 zSu}^g=^FcCnL`@{7%|`qD3msFB(4umUK3frSI=>eU)?;of1&1&QKp7(oD3rWSd$(3 zEBRO_V-z!_4~zK|`QLG22mYlK-fC!PK*%0p8(Ub5d}#z2_b$k|4P_uVd{rrkMNFJ7 zalrRJ)^~A~Hc{eZc_tm&u4(ze(%N|vVc&cyxH4tY&kkUrHu=%cnd9Fcxyqu#v zsFaiVByA7&mdDunBzU@#fq{;YO5M4^IZv&rHWD(HSOo2&j+|TyuU(`}_A?66TI>CW zs~L*d>2~PNlk5K~P$o5)V0tm9{9&kd2uULy2Mz5zwIH6C&208OJiqAl+nme4@T8`v zwgq15+YDpqRwU3DM+q(7@alrGsXnsN1L_yqf)+N7(H8xR z0wPl)cC7~_7^LQoRSMx&bjHJ*uh_K~W4P%#rA1Dk?+z-d`zR+D;FAS1+ zMywSZ!QSs%Im?DiMDFo=_+ZSYR5Hn`FH&iCSA>wz$zn5{MF$3kg+@dZXbs#*ZL2t* z#yF5x4G>Bj$_FU{)`%51vJIE#^9<321Z?_*Z%#;=MEFO7-`(hTXy>U|@5hKgqBQ?> zn12LqI$Q(qG2>*0K=V0agRLY9yu@5RHNzTEDjhD}xH+Gw8N2MsBA%=J%rQ8@#b^w- 
z0aQO+@zCBp1%rkY+|7sWPT7q`hlq-ma@y1HX_Ja8>wXEtE?#Jn3vvbFSUy_kpq#Oxt)jS^8(R19h)PJP9e+;MTUePezM_Iwa*qf$ z71}iErhziqM;ero)`7o69wG+XuZM_-#-x}@?wMsMh44beT(lbdkvbA_XstQ!s2;H3=$BE7*70k zP8$b_>++~%@9oab4rNG7k3ib{Dv3`!r!VbC2m%cyHR$2SDmmZmwUfA^sG5K45~ya` zsVO=aNN*S{X-i1A@Y*&Yg}BkBTVo@aS+X%;p(`q)Z;D zhGHai<_PGg?b&bGrYPa14fWM}R#g+oHOBbrA2i0_kl8Qa+eA89d88pA?aC_l8pGDM z=-o%L|HIZ-M#a@^+v4s7Xxt@0fZ#5H;O_2{;O@{s;~oeWf(Q5D?(V_0ad&rK=iGbV zJ8yjF?;g8s@2XXE&1E#i)%Rtjv3#FLp}CZbCqF4j1bxJ&KqIe7yQ`T4u4PL7g@JKM zBpuQJynWY^1Z5!K?3`w5cVZYma0Lk#e4EM3llUoEB5iS6=q-A;)5G zYR7Nu|M8puJ~^4he|uR$vrsufqDPB79K3LM`4r;+fT2B8ag#WWOedJ$3!?5yff=Shh3@gm^uLX)uN&bcHKF^ujM|>vdz1h2% zoB#aayfAUGkk(|R0{%N@Vt;=yyre(wzZQ5t^LOfMp!P39sgRCrT0L$4%hU@`_JJc$^L4*-e$Ky>mXxHECyVj z{%y@l0(tDY5&AsEKMMIPJuyao|2(CRaG4B$U(0p83KI&4EaJZGe;D)6#E>byJGt;N zEr*C_T+k1Z@c+7F`Tw?dJ-$zhga51l-~%2M{%0SliT~Y4=#RNK#*mK6FE-0}{yB9+ z<-hwP2;;eC+V20;p2+^y8;UjSIv1pWXv7SUw0{oI;Eg1x&@IY>D8UEyh>$iwnFgIW zbeqEfXoj1W}CqL1Z%i2#YV*b+>25 z2AM=BeGn{knmj_tCuqtnkDx`<(4h;(L3sWqeT)!&6TK5nQ~|Bw>a_zx=d_ay;}AT$ zJmk5bXE89R3iE{M{kC`baM0JjDs(}wPAg#3`c7N!tH50(!3Lyvk!;33?RIv%UfK#! z<#@U^uI=(^;Guw_R6mhIzqDKX$@b@XpZ0`Iiw>3}Fj*}AB$$Go5PbDE#thlT9b$1t zBmecUVZkUxv0=geD$J*T&lduJ=!|ep(&pjj7TD54OablPMs|T5;g!i6N#*O?nr&+I zXotAU^YJ|)AtC=QSP>E>(AXlD&HDQA4IwvtX8YZoFh6@`o-rKhJI;6-^I!SYe6&gM zwC}ZxG>g}FPf_wUeX{mFKZ)z8lS2+*F28eJ!)mR+D%V`CiV1v2Wje~t|f4$eZ21|yze zo1j=nLmQ6mgtUxI2n#hV0w}Nz_q!~uzJbBU@#d&BsCCUPjzGbmC$=;3bGB%=C{}#J z&!CywYr#85UDi|k3P&e56B=JhZ+;w{#DyO@buw92LhqdUTK!r@-z>u@I<2cIC@GQo z+cIB<`})#W6NQ<+Oeo1qOX%u)-yV@*{t^OPf7Z1oQdON^&@HH|R4Ls!Jw-{=9lLw( zmgj@)UIbcyjQ<8|J%^D5KXi*0cRg|uUkr`*g~us*QecS|DMLBTRK2DO+A&iwT&zam zegRuRJ$=gJa1@0OfdU%Yf}3$1T3<`h)m4I33VXsoW*HyV`!N@R?@u0)4ES;Lhm&CC zLsz;|hK&CleKfh$FVk;J6DoBZ!K`wU-I~GUjWPk`JbVKgGKsa zkzQY4nFG#{PA7iyqaMxK z2P+o#jBpwCKliyP@7>IQtC$$XvuWmmM-oq|=shNMS*(F%gal2_jwnCpQ{Yejddd=V z!Dy*J>q=9za(o&mO;xb`Ua5?6!e9Hj!O&oRYb%je?C}#C=h?PPglKy(^vGk1*;-k? 
zcCD>Kg%5Awr>?uStmwV~$`@#vP|_^_$XC!nRo&0!Orv~NFWXg;NZF*}~ux*G%pUjgDiRI(NI2x10ibiKV=!D z3{NJNNfo0jmt<@jM!srbNfMK2v8tsdlxjeLxp=diUDRn?^T)IRT&cLt7kqEzX8k9SDg#Pv{Z-mU@CIuT7Wi0?Db z$5IBc1$}WUp52lR?&i`Rs*Z-WiqVE_dqZ!Q0oY=8PHyg@m06hfD~%Dsz2D>K4&`pq#kOn0I@#m=9D%k?6b99sKNuNF5ZumvfC;jxn#;bp%w zM=UqkQwoo3ew%FF3Da)%rEG0$yGl~`MhFe!38X%sM0yv>mdfqLVmWF!%;?Ri!90=x zL!$Fm8(Kug;Hk{Y>9B}{C1GpW0WI`|5ms0zV69$hTFPG4Flx+}3n}nwJHe_`B0*<@ zMec{W>Vw%c#oBIZ#s9T#p~(*K2anH$rbT@Eb*d`JujSUHcmHdn3r~W-X{3!-64-58>`H^ix*f9gZP&z?Dj z7tcSjQQ#hNV4*Z1xLg}#LS2}vHlUfx;{u^{$1%;faU-&GRe?)9MRMbZ6|7~Yqw%2Z)H&Y(!dPX)>p zTpy8{H0`=q;F;D>b|TPi@N+(2-Px|8!XM;uu)=lC0Uvp$W$GW!~E4gQjb~)j0pXcnim8;fBSG;_07Z4RCcIt$@yk;wM9AggTickPq^Nr?#+$ zZBpeopom6@71VrlRJG18^SX_;&V+}8uRY3mBEC%}1Fw+5h8~9LA6iekWIRwPpu+|{ zDVW8?LU1))7_KW)J44rRdM04hED3-JsXX1flF0+N#UJu#cvkti$(7oHgwvKeu8Zv9 zk&!{AQK*__hOyobbDwo}iQH|MLwXQnji0|eF{l;1Q)lG*Gd*49vTQG(#)htoqJO#^ z8yZ6-s7(5eSn{2GzLmh^**y@Yj6WO$QDw8}XZDk_G8c~weOL#~#%{b?jz>PVi@=ee zOMTdB*3HV#tUY(TaT}7~><==Nc03dj|bIp$q`Im~7CWcrs6gGm;Q*7r@V&PZ`ZvSrIvq!z>9dkfs;&MiUw^=+&oe>yjS+Y22v zrw=-J^DDQ$%^vo59>*88;ORZSJKupbmvnxeiLsNbmD}f^^SqC$lVE-b9J|G5{cpxU z2!99_SSSw&j!q2{ZsVtPy+9kRWeG6n;T}&@v z`MT((LvR_26BRC4Ol93vdnrC?(HU|lLqy+ccn@?RbM}ghhkBzEsuU2u9GK$pgbE(Fzoa0*7$JU{S_EKJ6eVEnbdpW94(tXcA$lnq7aG zPWw>X^x!+0uLG~9^k_xh>!WF4HvdtqGPLFB-feT=heeE;3aw>Hfca|NJCvV7Zm}m1 zeUT42vO6g9-SH(TS}a*H*{ZZAvquY=N-hf#zYyTl9gt5RAW}8jvb?!xYz!WCPf;u` zlN6N(?aS)qQ_xADqOSJX?&wR}9A#;1m6qrqA~!$-hnlDETkYQTH;Xs!AM&HE85Q*G zJ+^FB9nd5|CamfNj-jAc&|Jw>Dg;cVJ6AZwXeN%)2PI&V0M&SJeXaF&IY@@kniOmx zsPUz5+~@yuK>8njvlR+S5H?usER$5HzvIs$VbubR>anAFF~`f+Q?OWioJem0x*Wu@ z!OC)JkA%f11R5Km#e-Z@d!OXz(+MWNC_uoPjZOi8K+i7z3o2+!*}B?o3m*+2h#}Nm z*JD#>*k9ztk?I=+Vl(C+)M0{Q2jN+-wEL5u@MCs{plWH;ZKi$$UDBgTU|&>U>gy zTC>M-gqD#Htx~sf-zD4VPl1fDiL9wDNTe;UozZ+*of8u&*if;cCAal;vFkG-zzxyN z=X7^+v&+U}M`!n88Z?G%sUwg-NtaM_CQ`?OPxbz6KBoWk1~9e3lcm}H>^rN8sLcmz zYFM+-xSpL6tt#UKLGxb)TZ!*E9FOKmYwed&5(^Xq($k4qjM}5C@JcUs2#l&dU|weY 
z+wdPT_Cb&6hZZ>Gt`pO50x{nx%_g%q(g9rdH0}4tsAc+O{etEU;Yv2mX94Y_%0C`O z5li*40$_bGO$`7_A)Z2 z>cwqbk7qDw<&v1=+1~el-zXFmM-oIcVYAqDtbdDSTmt?+euy3Qs^0_CqlJcD_xyaA zhAI{E5v~ky#qVA}UIgpVZmz**nSlc*J_5(h(TFBt8OwM5HDG;WVNiEjdV1equ{qhY zt*orFISuQ#PVbM&cL!)t;W#Nt&HaMQ#~sfSh^o(cq^1ACVe1^`^_We&KQSMhJSQqH{vP*|)coYm+EUsG1Y6=4;nmo^I>MbgEfPGIGGX z0RJ}r8mkr`pPdtMFT{=I%j{*rzj&Cx-h=;17PVq|A|Qb>SC?y8$M|R?*g006u@J69 zT&Evt(-xhNQQ2O2kr1SeYOoM!WNPilK7IdSfRo{a78q57>+aW#f>$yN6@5_8wCzlz zeD2>VZE4+THDH~5AjK{{TdFav2-prSR1MLaugG(WGk$)3f>$=q_(gD8&W81QC8|Bl zlH5()S(oT=t~f-kaGX+@@|&O6*;F0r?kGtd*U`bqgC>4_sdh1m;Ogr6kpqy5jvJ%S zT5}-!z-;P%RkEI%TV@x`4}?7IOtMpe2zTlMODO{= z&0Jod&X?T=nOglF3ZG%QPDWby>)8yS@~vqfiiJv!FCp6TJi3B3CXuMKzToaSYTu8W z<@L?O=%qSqD)GXez9=9igmO4*n}*7fv!bM6qfq(-;e?)VB#_K$(HM%_tnKz8f^hCw zYdImQhWS>5sk_Q}+;lWmfVAz=)d~;RK)`jtrgDBbg^eO8po9?NmN2NXQM(NM8~=>c zHdx*7$4e50l(jV@U2I0s2jT`(uLT_S3Vq~^C3lO!&r%_Q)SZe1d_nv2hM#kKV~BYs zTfZ_%NJ^LUHHUQ)3{P9 z`g*}!S+WLAeGxtQW3f)Y9u`HwQY{;bf;VFJyj}few%MYXl|5A}4!Q3K*Xo7{ag6OiRWFhv2?t6-w;RJMg2HXA| z6P(D~Gn2a=rdg73?cRKmCgE~@Spd)}fof*CB7O+Viyr5M|88K`W`p-P5VyjxBO=o; zbPH4v^BWhT>GVPOr-%mzX$yLPTdFqmPab$7A2-H$w#M5#E>62S!iv+1BF%iN_Lo;+ zS$DvofUG3|e+7}llC|emTuSG{G>H*U`}TFwrGrB?N9=`%&8Oe9_3_arpl|S8mrdl| zv?YtB1fm9Sf*1#OQSICGtAStejCc>=2qGY z6)BfdMdY7+*{WVd%?wJJ@OI*e_tU#(ckL7FXL~&rGKvZo?20pCw|08QYd>WiudW9( zIRA3l2HUMW|Ha;f8r-QygQi)jA61QO*|oTcT-l~eJ$;=TBuCQ(sn&-%S-&b(&k9EQ z(Ehjy5PmLViX>nSRU{R$g#J)233|M-&Kp0f{G^Zua^)SqhMa!^4>I#$u>p7T9E;_( z`^^QO?6RhIzvkcBfRO+#>x-M}EoN=y3jO!TYt=-1{tvLE&1lT2Y*mT)VBi(3u=Ogg zZaG3reWfG_!Oa#rww1<90^8|sbF$t{WVccjQ3LJzWI7~`U4ZW^rEdACaT-j7nG||Z zWe0g8UPMfLn2t+w^|3#d0e6LwH7Xp`Lca#N(cQWGFzHG;UlLs!pSt~N7cn}Wlh|1$v&IjKk3CHn1Pl34J_<7q2AY=*mju9T!IdSnqsnEYWU=5V5 zTvTOuhu=HUmm~Ld>V<=V!G}q;hl7h-?W@9&%xVz911=drOMQWTzTy&c$3lR8LhC2U z0uDXOYQWQENt_8>I#B2&c-7!jJohU^XK^M!6vdzVn3}@Og_ie}OmcY^%23;a+wc?> z$I))c+us7~(tq_7U79$aL)BVOqgW?Jn($iQdrAid_ZX2F&mv1GW(wbMAlwm0Zg0~_ zs4F#MjQsIN_*>W!W 
z56$lOJ<-ED_#NGlCykOf!;m`-R?}8Fj(f4ty@#j(rW$HY~-ala?9ZZJxY-$vGQN7A4DcITOE4pq|0Y}ov6)SDE8`~fGi zaBM)X4^}5hGDgcNS*1DZXT#5bnjSm@u{?)^p@*Q9lWoR9>}vEes9oyp{J&T;P*8S(5xe9b zRyr=G?1lHlpAbj`vB8 zr5hjQDy%SRSxsIoL}ci;-ed# zZulRsiO1PFcpS7N`xSn`W{)aZr=x)V+}@G$W17EWKHZi{RSTA&L2N)UZ4M}6fM{Ya zXcZN1zOphz_;H_`gRWMDeG1i!xiLK?GG4TfpD#%=0IruRUg6bN zbTi3l@t7Z6Rjxax>c#qm-XOAE3e&YRuT&uyVaaH+SiYLw_0qH0JdcL^!-)2vMte=j=2B3S_f&N#nD^IchO zBprTVb;##|LH)TF)il5u8rh*v67)h2Kf~*K4W-Lv0~*#~G?{-rVK6A*Y!Sjnu+wul zo@dsthx#CFjo^UEiiCxo{w^x84RPx{X97?>RbfUS5$T?FfAK1-dH*u~8j=^EyAVTmicLj0$v8mNS$nnDCx}2qmjvh>y;I(tVJ)02 z7xyTDR@BqCn7><@Fdev!+!g&NRsj}#itnZi-^AY|g zcij*e)ucJz<_Oy}3#~YC#|aV= z0N;h*#IKJ}M4@Ph5eI5gJv}_(ze%A|jpgN}Tt|iGln7X}n0?xbY6_KldW)|ni^ZAB zbhyaJxAV-ajX&>u!(nFizG~{!`C=T-l%lJMndgLeu)adb!lxRP(7Z5_zLCn=U}0K3 zT)GL1q^Rp>6>3Ci=S#V2s>fRe&bLOy)3uoHsUE$H;UhPwp#anlBZ-t#5MfS>+m`c8 z*=oRw2K(Lf>T*N2RR?yo?_#SH;fhNs`(K!#zur3g-(W)MUzor!?b^Gqv_NpRA^ybv zm87e}O{uFZ07^tJKJ7OmJofvcq5G$-BepbKRYIE2=95PkeFR|RwOkEHH#d{T2O9Wa zzpPvgEWTo1U;}3V)d{?<&t! 
z%)fhbcaW%Tr-tb8gYkgQ*EXriat|b5zgDN~Nlo<%!$nE`ZL@nhJm~6=)n8UCFj@2; zQI{JkdJ#Ve!;OYMCE>_tcG)}cvjtcLA`hL*52+u6YY9W1n&gfz|}8{ zsX5|d6(Gms5r=BCBiUN(`5*8|?E+3PFNiUpt-blFr5S3{&13eJi)S0XZ}_<#7RRdc zM#mHK<#rk6APN`3aYmw)2Pq!H_Y|Y#r_iD6T zoVb?izM?0ZOWwyg!eJtHg@~}4%r-06+D~CfDUpbQ$zE!(5w7K}ak24)TCOdjRj9B@ z7O&9Pym4_0W2G{@9jYP&pV8t-gP}FJNCdn|RYs+dVp@tR|;tgc`;3%;AC8|$q4MhOt+5B`RPF@3NhE@?a+TCV-umogL%)ATG@ zO(e^7%W3OTq1FyeBqkYimX)h}QvxAt6_bG`NT36?J+uDf z{4%Oo|84VeyR2!>vxxOu2>!Gfi(>B`=-*hnanxqk9NcEt6ifN`BQ&E0(`2=u^!)0k zs|Xv`w#%M>o4r>G3@d}964Et(V)EtBFYMR*7ntzxg%5JSJo>CPpofB73MX>RQpC=0 zZ+qurMwHtRkF?vIk%Ph%hO^k>M@+c%s;p-vqlpAaXG=7?yM&0b)PH-fExpu#9)5Iu z$vPTkNL#~2ytD5tb%-ad*?0w`*iL9IBf*6Yfcz%F)o#>}Xn?*}v~V`Zy9>*w!%GVr zY<6dq$GelrSsQ}oSO4s~LUn3k#z+m7_&02?k9A*eps&9k9G?Y{wqTeZ73|zkvT#U` z$w$Xno#b^iV4N(6KRR^ut*f%*sr}n`r6!r}a+|j3j8$ZM=Uz#C+u7wNHW%Re0R|6gGvwT>uI?HRoi` z>*Vv>a($MTj*06pnbk|RRUE?UyO?f3#MVBfB{qrw_D#p66A8p(m*!<6ea_2lR1KXt zfY}2n>$S1H8@0$`Y5F+YheWm#mGoH3b-K{tBlQYFA2@?}yf-gO5c4tx#PVJBrL@ZpYv@3|iSFhp%cE0#vuGjQPUHp;j`{-)U*6e!HYaVh*g-#MI_GSp{^t_<8 zP-XFfzuqB^6EvnjFPVbocEq1cppP8`qp9IE5qyx};99uYXK zCVQ_Zqe_8Vfag++t5JP@?QY}8GTPLScAWi6L}Mk}zg-PK2>ja@;N&5ptA0;_8~2BJ z?w@w>^7gItjEL|d_wKk+M2MW$*3A;v#VkR~#^x!?V6Sa(SFBy>lGgc4;2-)F95fa>mr>H6bJ?`$A4pama0g2Ve`V^tHRqyhQwV;m(F4}NijLe?%UBJ-=}XOY>k(9T*(pO^9;!V8KB5x#mss` zja9YzgZEDnAojO!Q_^yBRv+;a7rS8!&baM6!vlW?i;dtbES@Dsw<2xudh@B*&>QQZiV%PJkqHdPd_RE zNZ7HHn38+W*W1rgQ3jJ0qyyiD=wlENSgDt3CtCnrFhQTj4hb%Ql@$y)w5>v!9BB#~ zd&_rDw&qe?joY=bWltD;dimsNq$kS@hx;o*AB+&NZmcXC>CCE3WB%w~o+o}mycJvg zZIKH7MX|f{^HC6M95ABmMSP)o>Hffbrc@8J&HLICN1UdT!}eTR+qgf4UU6F%puEh` zS&(Ft3WrHbGwpJ6$Q8_@6H!+>pC>Vt;<$h1$6!SE1DE#h80Ve>FRS z$q8a^u%tNI*a^((31j1WISTclkK!=bx7a243ugN2aTAIA3L1Al^i zi}L+bUgxa={ZIY}Qlhr`q09G->dXU37x(+AqbRVFfOvYCfI0{#-dpswNUspT z-VQlLHD5B%Lw9R@Rod6z5!Db_LrKj#sP}#X1@8V zNr0p0`|a5@lHkBab4$W!4QoxrXoO{r+sN*-_D6h(JHgg)5)cayucuXNpE_^oc5VMp zF&FCP9u1l_(N{ujgc*D)KoVo`XUXce$8-l^?~=yVg;OaeR0z(-o4)w;Zids7JDe`3 
z4`yY%W};V2{KVAG$Qy<0ko@%lnAAEy;YEnc;nr0C2R}JF@r+3Xg8f=1glsf4Bx5kD z;M|tDcZ6&zB|Cqr|0Lo6&B=SZV=>4=SmbxNWk`&(jXu;QH(>h+2z!d}>aen}fG;>G z18_Z=AqrBBnE7R1dp1K(-_9QqV}mc2uk!J#Bc)WMzqtyY!JwLI`gJUYO}Tl!5_NB? z*^gYm{Ss3Wojky4E3`X`22c72y~z9Ar>QP^nOuEj*d`RKiJv{&?_HcWj}pJ!b_tBy zDsLU5g`+%Q995=vFrabSsGB56#C-x>AFZZo0FIYk0MrA%-Dv5gzk)Rqu(U_BzVU+) zZv%{!o%HB_sg59l{6YQM2QBE;c4`bOzA{X!hS&*qp?Xj?EEC2kDKy^hhKNcvv>_1Y zoO?GhR%*+V4nlH1!I( zM7=yX2R5u`{jQ`gr|3o3ZV>LHQg=Ym1|Kq%6kccXTZZJa=E*m#tZ=_Dj~WyderJy7 z58-UeOtX!y2RN4U=^^5XvDu2NJd-UzR!`4H@%L)$!(gD+`)+G?{R$1}ZH*t;OS5Jh#^#3+a@Du>(X*VDqlXd4f4Gqo5b&Bo<(0-O;pi914YP>!7XI09r_wCxwSV{wDoEhxS)8T)Of8ztps8_SgV9BsQy>v4_0O&wo8S@ArakBkmBYc~l)K;ir&+Qxw=P3OrpgD5pC6aA0DP??{c*$ytVZA<3*_Gz zK4N$Xvh=pvHJKJ1+_u^TrQ&LJ$AqZE8$;EBvmL~^&!Axp@xGJ{ zsm>dlMo6d`uu39EG`;=51rJQ2{fktaU~27I0v1n)vehBq94G+}%!3HX@#{#if5UrY z`6jRcivPk3>E#I3PH-2;IhiNI4`uD{0GO+HU*Z~x14mLSs*?4-7%p)v9DoU4Tf=El zQ+^+NQ^Co8XXx5?V6e%o7}GTo%}T6qlC0=jx;X+Mt8DEML7YD=6GVF?(fhWGaVf)L zYGoBF{witlk{xgn8hm`_?hN-zB75VnLHK74K~u)Pdjps`if#N}KgbCnA?-yO)i!5j zX!}Bdo#B+v`-c{ud2>ftVupvp~&fj06&yJr{SSDUYUQ>5a| zZlBYNG|L&OsHh0X4TaZ=)SIXTWqMXj69aQma>4#D=sgHAopnb%5LSxH^(z?py|>;D z$jp}{H9i_wB$FJDjL9Vay_mGB_wKQwsC=eU>`Mz~uHAq?gI4{!_9tud$ZsS@_J7Hf zg}a<7>O#|l$eaCzQ|f1z436uW*C;~;05ycQ>j5K^B|9Vn8fg39_ zRGS`rK_ZjM_P;JeleIclyLqC2B!fQ1ZU;-;wtE+QQk1B(V1`J#bOebO!)RzKTNH#9 z3J9+@l5knKU7Y&}Ll9He=Lw(xf|vvc{%S}R(>=xywxYYE<^jLS>2{wyp5B;{+FwuU z)hl(e>kel|c1_Ml4zf}g!Nt=Z0Z7n~mq*8^$)yKQbS4x?_czB5`)f>a4W|-Swk?r( z!=`o%zO9-K!p8Ej-xZs15#l#?ZyX^v8gy73iIlV>`DnTb9%CIVR0cJtB@Z*_Hmia< zg-Lc_~FW$Co=JfIlJ-SD+`*&>q~otAS_QtH}~C z)L)0w+W3BoGPyiQvpwbE%cF$>u5Ws^wg^h}n8}}6KgFI8ItHT=#ZF)VV^fTZV^^O% z*taZ)cgIc`!r8L9ZI`P-*a$boGiGGyV!~qa#-D$DL^sJP{pfed{GHK&T?=ycYC#e| zn60ES6l6<=kOviC?GP(ul{0#6(1ZdhA>75p=hTe4=!}ezTTHqXwtIM3Fl+;kfI}J zSab)4(?=?IF}e^fOD|0jjvXqW&Z8U3`Pu6g6OBv|rEvT?l2APSCKUp>?8(=^*}>5aNSTZ-d|jyW(#j-52uYx^zQaz~ z!ksMNQ~vq3g%eA@8W=fCWS%AyTxT}HvxLKG&feSCo2%B)b$^&n&gO5+R-izLnmz z7Nu=LV1v&deG1tR7v4X$-hfxSfK(>E7#T=q@eZI6x;FI-p{^eGw)4! 
zG|&8rtAr(+wJ{edPcdbzha!HR*$X(PMySt(IE=%}(I=0z9IT&?rhtMaU% zG9I^kV&ALThW)7~0z;oR{5d;AP(T0_01d$%osx-qkVajO+ z8*oU!U`k1>Y6S_@aMTfk3N(@!T;I6gdWLVkv{sPhgnP#nx?jW)3~r1dS9*nWAZOhT znQZ;cF&EeUgKY1bjp8jx?M)yg2Nb$>IBWA|u>?MmcUXFUuy#ZJc zPA5wm@yKqj>S)8stmzSae!EAI+?|7<1rw=EFjrUS8%spe!yAo_Jc4{;4u=zsgUK@!#T;1Mh?9zQD%*qZ&x_Y zuFn87`smTNr36Tv)KKK!>VI?NcYw8co?l_oAyF?B>-qn4Ss9alYRP3Dz~dtILRo?$@bTiW>cVkX)01BUy+uo?jp7jiz| zNB(tY5OHR&pLhiY#BI0Kh-?S!Tqh9m`&9Wn!7Ow$n&|C2OvB^!755$=EA{Tw=W!^G ze7ZQX@a9i6!OcRiRML;c%^xOq`nC7stH24sXP!8BuSZ05qN3Vs&T`!*-LqCMr7N&s zhklOH@>#5a!+a{xOSI>-xd%_Of?0e0N#>l9np)gUQs~!?oYt4!QV!13jq2giG3Sn? zHLOZGI|zA2|Ks*gR;_cNb&Bc`vn@`I)w7$dZqha#I|1{LLPdizvZ&hYt$d=7&tx}| zXT=Ywd&L$<$vcr`Zv!YzTB{GS?g7->Fno;|Jw!1Z7p{p5}r3Q~IR4naBJS#)} zvYxQGB8$>wqL+2Q$Bka(Lo9;A?%YovqcOzXhqG@AGkOl=ZhQ4a*!4OFA(haA=u7E^S`suYhI0saXZz%KTrW2HYNJkjK)W+6yM6XKpjv zXUty5PCs+M7qXr~ayyx-tF=`2OjXRhm!&>SUspwYwGhi~=XvG^OElV%#bv&&l5_1G zfEgG|ij#rYu@z@lOek`HEJL0$<>(ar=ir-ht&}ucVh!p%_UK~Xrp#%}SCuW61$yhI zgFLK`nva?vPCb&~&~8r_3&k2%PBgh$o{scy8im_A%fHQ+e*exn=C|k9CTLe*8aLDE zM2&}sU+j=&*!+0NeQGG+dJffgvG+C7p71Es%KvUTJvntx~7PkXh}z@mo8e zySy2_{GWul;9f>&(J`oGuz*2uoeUQ zKeNj^!nonForivzIGw0-x^hqH@21PKb#6Yk1P17S1QRo7#I$W_Y8R~zWiff(t=07L z*yXaaoLb+U&KjtE&`n{juxT1cMNO)`l~I_Uq(pAMEMf9<5LmK2sWlCIcv!0cprcoRs&1*7+}+bX zBhVi9L=EieFCRbKr9`$~=}pP)TkF8Q(HhEp3#~~YkT0fT2iTTipdh4@S*u!P| zwR}_l^ScIasTBc5BQpLHBcZECz4usTn=ZqH!*j2Gh}U7om5mKZM!*w`RewlQdlW>} zcF!KS?N0v`oWtw1cYD~2jnx%R1(2w!v$E7Vtp{PfA<51FFcmHmI`S-m5s72sn2wIl z-spuxX=E8oS3vOVp1jqTk`u&X1i4(ldRXC$-csh^ptOhJlgBc*!%Eq9M$1g82dmq` z(!+836*pnlx8xiy85_6VYeMwa_hKi@by|jjtbXxz7Yo&9vbuHs@PuXfVk4T>_HR|n zPYqj3zrU8ol8cj4aAaa z7?)rn7 zU(f|wiDVH71=|kXk}fp~>8yIYy4;W{W(eko4ampi*}gvDCUV*?h`i@y9z97+Pt5C) ziv6-1t+Ui_7gnJ71fq-f(I(_et00Z(rgGgMH1TWRRm%|cl7$ZmvG3E?*Y|NdUDBDa zjZ@RK=k5wZa6Ql>WUKP`Zd+=$2&+F?*g5oVTq3}GBqzej_4o5BwYg()uP$bwms<8X zb}vP;3(2s|$f6CNTf4mjUkQ=Q4Xipjp9s`BpaNcah8to~)8k-c6R>RF1gRs>kim2_~{Q zAH~L-WU+Re3e&1iuvVPZ!Pf)(`~cFUGHADKDZ89vv`U?rEX!6t|C$(xrlL$P{2 
z3ks$1+modeLt|;ViF0_pbnYn+es@PE(R9wZ^5;yY&TMFd8&~-Mlm$c;q6YA2D-w-6 zto6vJ2@P-K#XnD;sO!wTyCQL9Z?jnI!W~9ZB0qHa(|+dcN&kMFcp#Ur=I(;KrYX>6 zm^yJz+VjSp3;bXvpzK~_7M9jR%Cx+r|7zYx^m4Uo z=5|s?TFY_(8}hlHDN&DK{m)mwv+*ewKef%WV(%}{3x6VrJOl-j>n224xbe&r*_zs4 z)vK}DzBG!k5=?WX3)T#E_LH@&4I6@Y;#7-Bcs?4;S|(0^qg2rh(mrS+w4(2Sd6M9c zR8J*h(n_=62?W+~FQ5bzn9PMLTM$?@YH+@af`Gq;|IF&WSG5$|U7UrxO{+R*@}E@&`=WPqfqr21~~-LL{ZUj*6kbFdV{ioshuIG9R6NhOoWi+cN; z7U4ttkl^w&ox4H7P64p=L-HYm{=bjcC^qJ4#t=36J(hd+KVMV=ELSr_1|6Og-79hk9{)i@Mh)!*`~ zF~!Q8V3L~u@AmD_&`n@Y8({ra2G5&#jJIR1UP$AYiDGYTvXjeW}hwX%fUI9(pE4rj*x zh|Y}vB8LVAJkL$;BQP76!YO9(pB#!VpnwF$TanF5oNfjfQ8xV zdLOx5``a(CE0M`T?4IjCs1wC@VZ84<4%gq{-M$BO!GFA<{J%^3ySQw7cqG6~Ina;S zN5*N3RxbG zwbJ{J4_|yg3((HR2QA`**X#@STla5ML1)^zs$lQw>%%&(uU;0o@9MN^i!O5M-8cOc zX8ZGq|CN=|>G!odPF>QO`@Sai|J;UyM>9S~qWYr`WTs$^t#z*FtVN62ziVq|HkZ}! zJMSwM!8LWpjEqP__tmWzHu*}Q{IoOr;pM1XK;=Rz?T!nkv996xvM{-Qt6ApUN#$zu zeOF)6(eirx=+^4^|NCB5*1@#j7)4_YZ|o7)~npT{O|7Wn7@jv zGeaUCc*l!LOFw?)vO|hrH|j^2*cG=82l)IBC>+?QVIq!5M2sv=!3_76P($Ycie9t-fW z-n|+R_CJqatMost=)C?WDffT7*tTA=XJXsRaxxK?RKfN&dTTlyco6goxRXObA)$cL zqXvBM`Ee!|o>w0qTv@(M&|roCV#jDyO~A|PB|cnTkqq=^sSU#%IiOK17BVn4OEIyX ze#gSol)B`lI`Tp6EL(t*4RZV1E>tkF@g|DAWM<-E*$fm*ko(*H!kV$U`@o~c9H1M~ zVT400r+~o*SPRxmE$Ww4U$fT?Phbb59q?5O~hz2m<=W(dxt3<+p)8 zKOX_P!QKjJo(>@=mP2 zp2);@wK!QHCGkDj#3f*`W6hm8Z*BvX{Rjj3w^7&OfWn63?>B9o*Kja(@iP}>D;gZW z0<&Gn{XW}lSD-)E8WahD^5Qv=2PJD0=FVtvOz0I~i+l}Oqk^bHg1{2y`1)PG2NwM6 zf!ut(O4;GS0;?>?9sG>V3Nj_TdXPh|Q61=n569LtZUvghEFG|`nSsd>B=(|hPnE1% zLV&ILHz#;NIRIIDs~xI#O_`)rxp!;*t!I-Csfb?xx6m@_u%JT1g@R*~<|#Bd6ac-- zajtOxi@4v+Z0UMc=Rdy>2vKo$_YMra|7*?i9Xn>YU;Hfmi4!G&4yaF#e7zt zS*DYflz}OODUPob*?*`ZdVqVwA%zF+6Znu*8>;dEQ1&%7f5L&H6GaWn9#Ddocd|lC zvM36GL2wpSWZs#ejN2lC51=Ge$MF(PD>BzqTp=N#*trjnMF+TnVf{mK2{Uf@UWoyw qpR>G{pnQbGBu0H;loJY=#{cYRmieVWm@VVV00f?{elF{r5}E+)LduK) literal 0 HcmV?d00001 diff --git a/docs/examples/fp8_primer.ipynb b/docs/examples/fp8_primer.ipynb index b8a63dabff..788d6c37ae 100644 --- a/docs/examples/fp8_primer.ipynb +++ 
b/docs/examples/fp8_primer.ipynb @@ -18,7 +18,7 @@ "* E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and `nan`.\n", "* E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf` and `nan`. The tradeoff of the increased dynamic range is lower precision of the stored values.\n", "\n", - "
    \n", + "
    \n", "\n", "
    Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.
    \n", "
    \n", @@ -56,6 +56,50 @@ "As one can see in Figure 3, delayed scaling strategy requires both storing the history of amaxes, but also choosing a recipe for converting that history into the scaling factor used in the next iteration." ] }, + { + "cell_type": "markdown", + "id": "f03b58ed-71e8-422a-95be-35c1cc60c4e2", + "metadata": {}, + "source": [ + "## MXFP8 and block scaling\n", + "\n", + "NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: [MXFP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). \n", + "\n", + "### MXFP8 vs FP8\n", + "\n", + "The main difference between \"regular\" FP8 and MXFP8 lies in the granularity of the scaling. In FP8, each tensor has a single FP32 scaling factor, so all values in the tensor need to \"fit\" within the dynamic range of the FP8 datatype. This requires using the less precise E5M2 format to represent some tensors in the network (like gradients).\n", + "\n", + "MXFP8 addresses this by assigning a different scaling factor to each block of 32 [consecutive](#handling-transposes) values. This allows all values to be represented with the E4M3 datatype.\n", + "\n", + "
    \n", + "\n", + "
    Figure 4: MXFP8 uses multiple scaling factors for a single tensor. The picture shows only 4 values per block for simplicity, but real MXFP8 has 32 values per block.
    \n", + "
    \n", + "\n", + "
    \n", + "\n", + "
    Figure 5: Due to multiple scaling factors, tensor's dynamic range requirements are reduced and so E4M3 format can be used as far fewer elements get saturated to 0.
    \n", + "
    \n", + "\n", + "The second difference is the datatype used to store the scaling factors. FP8 uses FP32 (E8M23) while MXFP8 uses an 8-bit representation of a power of 2 (E8M0).\n", + "\n", + "
    \n", + "\n", + "
    Figure 6: Structure of the E8M0 datatype used for storing scaling factors in MXFP8.
    \n", + "
    \n", + "\n", + "### Handling transposes\n", + "\n", + "The forward and backward passes of linear layers involve multiple matrix multiplications with different reduction dimensions. Blackwell Tensor Cores require MXFP8 data to be \"consecutive\" over the reduction dimension, so MXFP8 training uses non-transposed and transposed MXFP8 tensors at different points. However, while transposing FP8 data is numerically trivial, transposing MXFP8 data requires requantization.\n", + "\n", + "To avoid loss of precision connected with this double quantization, Transformer Engine creates both regular and transposed copies of the tensor from the original high precision input.\n", + "\n", + "
    \n", + "\n", + "
    Figure 7: Linear layer in MXFP8. Calculating both forward and backward pass requires tensors quantized in both directions.
    \n", + "
    " + ] + }, { "cell_type": "markdown", "id": "cf5e0b0d", @@ -63,11 +107,12 @@ "source": [ "## Using FP8 with Transformer Engine\n", "\n", - "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using delayed scaling strategy.\n", + "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n", "\n", "### FP8 recipe\n", "\n", - "[DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from `transformer_engine.common.recipe` module stores all of the required options for FP8 training - length of the amax history to use for scaling factor computation, FP8 data format etc." + "The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", + "Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training." 
] }, { @@ -77,10 +122,12 @@ "metadata": {}, "outputs": [], "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling\n", + "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n", "\n", "fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")" + "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", + "mxfp8_format = Format.E4M3 # E4M3 used everywhere\n", + "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)" ] }, { @@ -341,7 +388,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/docs/examples/linear_mxfp8.png b/docs/examples/linear_mxfp8.png new file mode 100644 index 0000000000000000000000000000000000000000..343473283524d138c12123332530f07485d97ad4 GIT binary patch literal 49282 zcmeFacUY6#)-P%S1+jnvQluz~6p;YZt4k?Tr7JZ8(wp=e1gscE5CjASsUlUXln@XB z=@LMc5)`C_7O4S3$enM{rG9tseZF&_d!KX8ebzr~uf@!K=a^%TK7V8G`h0J;rLtqku6~+5;4e4zlv}|6?DW)A{cA@-=P}%l9o#!ET|95#XE95n zjXz}=thmK>bx-_xdTPnXsV`sNqnmnnYH4pc&BZjB)kO{lWl}NC*)s@bM#dYo2j6jj z!|zJ_YsYc!!wkxd=+xsZ6`I`p*yt}x?385wyeaRkC0^k+G8jBRr(-x=xiUT3awLm5 zq|@i0EOvb*P&+x`eahByw&E#EDjIrjrJFl;?)vdBq0IZ_F2XQxb^EXX{L_!IR32^? 
z5r+Tz^N;&zR7_--Wwd|$1~h?2rTy=3{`^4MuCS=O7a~{WBC#iaW3G}0H5Jdw!5a9< z52mog_i`A2=a7j2H%pH17eEOXd|303&3@3^S+9oLEzCl@%=~MMdg}dV#I{nuyXHN- z>~1kdr;o{22r;Q5RiI~r3Z=3mj|IR3=1a=oD`Z;< z#}enO6_fp@6L236arNYs!B_fvb_SwFP}AcfPb^B8y{{$~`VdHuE9PPk29Z@|CXIub zRutA>^I(e2EhVgN$5+!g;UdhcoJ%?!{l%+J77G=)F)leTN6$}55=Z13y&pn)GQrQy zl5qOVW3Hpf{j?IMo^OS4Q{Nb483W1`VIeI_lnIsrJ3s0{AWi?s?m;d zIVR)hrhK(g`WY2?-ywfvlMPHM|8`_}g2ex26snR7u{N zJ0Q`TB0u_dMvFL37RwG2;VA7&%u3cr8=m$|*7xPA9%=F%E_Br#oZTGg91Ofes}fOr zRM1L+d_O(Q$bG+z$^vN!!I*1ev7SYe{i@^py+__h+GSH{+i-Z&(f~3eNR>dp^?&L& z?-h7K|B8>|1V;GXR2Gip&-&MpCmyw1(r50dMNEBt;!~3OMth58Qh9OO=2djFTfp5M zImGLTj35KM0#5OU4h%Ccb$Z=;qb`$)WnFx5d{%lyI~2wWgx_ z_h!M!fS9b5z9JIKxl`5Q8H;mDC>PHG&F?Y&;*buo=qbTOPkxe1q*Ixo{Y1L)yCDdR zN~*wkp}(DwCnPpV8$sTYODrS3c`2Y9wp)xJp3nP z(Up-H#s;(1*Mtr+PgF;iZ-iF`K3Gs|eNW1~_5Q&~t0AAjHd)NwgZ`7H&9+qyxub79 z%UBto|7()?SZJvz9#e>cg_nh2t<|#T0U7Hu8JAVoZxK_lu`Z4qB>Q|us!9RYq#0F4 zyja6yB|E*iN=63lTf3MGSIvK#joP5C+c|!Qu5PHx&YVDTRb4lN@mzEPX}nOTuu*6g z@Aod?kOrFAz?CHt!EsKjbxh79e5I$OVpC@L#-w?VBU8#uFi!`JG=)gKyl~~i6UpGs zUa6SSHi2BZlPGwyslf(mN(`k>(UtUTaa0exCxS>SDDe4qQQzhgnUFDKW*UR?IJdm#$#mp&d1BzNmhg8@Bqf4aD;=}&0_u3QZ$(_{FvBao;5 zM2V!>{*0tx8DOp&tb-PE{LYl<=HZ{w&9hMfLvvQT$r&NXVqRFIGDqL^V&i z-KVe7k?`M6?$1!5XYr)eaZdiM~1=h;+%6Qm2N7%xd}g+WN>~L)dr}b zO(!$|UA}n;l1sMqE`=ia$wc)Fz<=EW6urH~bW%wPhFLpCkEA*n3RSprny9FF9(LPt zU4cbwWgZ8j_ZEa@Uo-X{1{t+O&_NoNU3WpIcrVAm8U$laQzeJau(V`h9s^pxf>UtO(tzsMttSF% zHvxkeq`{xd2&XMtY>? 
z=*#INp$kCGdu4ff_VUD;d0ZOs8FZ07-#ps*AUoNc!N$3)&j`0R7O7J`S8jS3{O;us zGmYBoU3#CvJ`h$cx<7-7N3elUT6Gr{uOTR+V2WU?Z2$}oXkt7yKQ z`J~4)rA&D;(3iXr>9BMvqcow5kUQ(x>?ZT<2s zKd`poNJfmyR@ZsF2D|)(E-tT&Q?l&R^m#m)xI7n9eI|I7OCbrWjyuG6F(&3SPxhnM z9oHufjm@tM%<^@v$Lk=ULQeTeDF#YNJ9`1T*}P=68HkvVb>TP31%eY zi}$A$*IRNP8a|-6!J`Pz%k3^LpyUZdVnQaR_9oS%f!WUk9PP-_H#Wd2_t;Ct=~BNs zs8Y52vBxPD->J*beOcYT?lKL}MZ_LgB(3-1G%QIf8=I%MYD=Os%#;^5jD@uB*0>aE z*NL>-MSRF4@!p;urJX`mW3aPgyEXsAi#NbeN@^~tIPjmk8I>H zW3HsdqOV)LNt9@-T|ju0N2A&TiS@i^9TRSiZS@i(bQG>shC@n5o=K zOS1f>vvNzGIKiW4F%ciq{PHqs{(7W{OVNMbY6A?0PuPQNianr+<{eRJMct+tG@a2I z5*LUZy3w30wyMH^_Z1=P#rxxcwtvSy1D^W#Z*SA-&!?XYi~4%^qp8Ck$))>ysh`oP zR1tRacV+lCsu8S3p(`JU!tg^gw@ylqEY5U>ge>%M=@gFO&67K>w5eLCthcAaGvo?A zV&BUXu(D-Q&S-dctMnka(Gs~pL;vyuQ1rWt;m!6ZS*{kr&zn?vImi%( zJhMA9*07C7Et8mqE#y{Eh{BmNE1P-R6D$dXY7gpndGFz=3`BYr4)94B1(QbRqFqrW z$>bdwG>XqrbDcPSF4+^P@j8O`oTT1r(NOM5f|w^YkMi)MZo-m|?D4tRc-7vQeUx zT(5Sc#_GD-Lisan|I#X=aQ{h*jRDDQs1zK4^BMG7UEvEyp5R;W5=4m3iC3iuWl6HE ze~l2t$YKdEa+*B@Lbe(`I}>kgDZ67x2~(A)Sne|$Mk*oe!umdH;)l09agl`K*hgCv zraj4ot>anl?y++zRcGeD_g=>(F{P*voTXGzb=|)Gp(HtAKrp1uJ@XT??2}Ul z(dkpyvZ;vPN8uUf$e9>o{IV<=o1)q!RIgS#we~Ua+Xb$;EqB74HIZ($$b(&e_Pzr; zKQ7yHYr+oq$&Cxv+st!j!gDj}NwSdpTANn6jz&tQj!Drp8!l9FIHfkAak57e5fGsL zIchM5L?R7W=+HT~?wFiV-1L~b!Ln|{I~5)lMa!mXejL^WG&x{zb`(5gWFl9&B5>p{ z9u`&?QB4on;#%_=&BBpuW6MdW3*$asDN^4%RX#HOqyV*E;I-IasG+#E?xuQSDThDw zl@XHr5pZVTt}>%H!z-SC`iTXYtNxbmFt#Gmwx#UVXK18R7~FXt_nF;gSo%HMa;eN~ z;LxKMS-5+y5GOHgxoHw5JnT!B$6YvD!#|9e?5J0Li5_}ce=(-%P_rVJ3=E~yKCrsA?AJ@XVh8rS4`4b z$0hWxFb4h59icKn`n>rR{g6@NQR_*z=(_ZmzAJsk+4=W0n!`g_;=3i0!s1nYkClIex{N_^<@oybk@;UFScgN{v^Ay@Px{~0mN z6A6=;gNnG*I82)xonq4)Mn)vEO>Muzp?WneODtKSc9m_$N8#}iSPc4RoNBxdOm8dA zY|}%gRXAn5awwqrp*87>3f6wJvM&{GbvBpp0dPs59VJ($TAl0F9Ar|Y71}@P=$Y2P zKlw1LmY&5D5=84BHJC@sskn?&8c#Fkn{8XQ4AJWyZa%xB4@i7UwYYyXomOQR{oWYA zuRUB*2SXb4WfbKXYAiGKg*V4b!R@Kn@bX0R)(MV<^%QV1DZM`|%753#@8BAS(}NY6 zS2=n-D9ZZJ)$f!%w=odhg|#v_M!5jI#!2k}FphLjwi8$cHOOjYlpG%Hf#8qU5G$ 
zjh?U}+nK_^{hMgNwA~>`I`(}kCQ?&XnyEWOCHC@UBSQQIw(7rrrQhcgF^Dw4d`j-jH^Mzo5UGR9C_we4vMYS@)j}Kq0%vRKP|#< zSQ!^hmaZtNXi4v1e_$Xd<2ls0H(QzRW5R1)hxG(Q%mv$7Yh4Als3Tj9&6}%To!YVr zE(C+1g0+nf^NI}pt+AKOC8k?q^0vuZGJ%&FC-sBOz1t!&m*NoaO}s@V93>-;w-NR4 zVMfD=>tlnWJ2yeR1a#6IVIW$Cl1VK51WB@(8M)G$9~hJ*{uh2j&y1uK*{{_wa_qOsW$(W_sc;5%^E%(G9M zUcWb>`{_)!ox-(YZ4?uEsaiX6?>t%GZlyKDWleyk<(buztHBUUR zVGB{%?0c8$^>ni6+*aIOFU0zpw~tNnH6r=d6FeB}{@}mt!uv(N{5zWlvxK-Rv05`!g>n7Ei0c7E0Xp-g)pMPR7#XL$ep>kr;BdRz zEN!bp95|OB_!k=BMILl0g#;`e!UQF0ZJSfaW2fQ808W|2Yt=g(k)QJq3 zYDx}yjjQOtbP*mg)@^j(p%xHIV}eei(5@)*|G2p zdF%Xas%Ieek$Y*WswCf)@CK{X)!_#c%n$^pZ-Hg|oJ2g1P#6;3=EbUi#V;avazd|a*d;VP+*!SC0JXIWa<-(rhW_p{lWy6A45 zke@o*bb{teF|G~MAF|k$?4G_kO%UGnBw3kniJDWLAR%sWEPuL0$SNM6elirYj#sc{ z$V>V@+#r*y)aogO{nEqE7t?I|20q7{>vAG2YSgdNh!hs}wy;`Q8Fk+6HBNkP-WR{c zB`xUQx{#7Pm%4x>kmeWW4h}|GuM8fwxor8sP?30M-CjU~i<1%A?v3>i=^5yEEQh}% zYnQtzOo{B=wL)va`WV#1*qfswOE&~^vnS=Jsd37YpC#nI>Mu;@0S~OXobm9ZcJpkt za@F*PB8|!(7_G|0rl65ScoG^P7FBb6sv;t7L^zFv;OAPk#1L<@RHA1IbdC|6o#ndr z1XGi(kuLPP$u0iOPtJlTJvc{Lc;(&)$Jo}^tD&1O?xE%+Y3^N-m2{pyXthVa>D)^r zv~`>{aYQm$SbjyfE4h3+)AF(I$9aO(2j{V5&+R>^!!WfX{nopvnl7YXsW;hlvtxLC zG2{RT@?uKe$-W1O7~aXF83;sJ)Lz4gV3p^jjJuiWpn-s7w}PtH3m79E#1#Q!B;&{o zA|~sTafaVahaq~i;hYJ;PvP9A{gyKcy3ej{)zYafHWgXIF!-{<5>m08J^sE`O4Mn$ z?|I9m&1dG5+64|>En8wvoVP@JhDXgC(htig&KwXJ4)pw16Cow(jO(BLZ1kaK!5=?U zo#SF&|DGF-2tBXZgIGFA=+5w{v89WRkO0vSLPa@0PHwy}nPlPP--LxCK=GBT39L+| zr}GC+9wX#8bBzE~{}@{kgPu@=OYMxr$W0ku=hov?hKUYsCY5;`z(m`H)t+!78Af z;-5tgs>7lMGdXtdT4JC_W^{7e8RpaDIN0mo2iZDx*-f$|le|}09m))ix)3(DA98ds z09g__Ud!E36PvYPKcXvWwlqbH`WPw_v&{C?t(ZS7dS1%}K>R{A$-Vwd8hGJHfZOw0B09wSZha!MeNy+Ct`Ty1o>uSPvcd#LQYz%r5xW{(mb0Wo zz1lojkc_Htf<<;O6N3s#&MQ@uCJR1kRtn5+2V~~vlxHBea&Bh$01I9tw6kp6X?NHV7(KQ3Vysh*Hmfb8+pQLV`FZE#O#r+1`LigJKk_uF zII~HGo&`Y{){BJ^>4jb__e;htagim(uoj|Iu`e6Zp9q2DZo~CzgmrVo7&?3O#?$lc z&OT1#v!!k~k&Tky!DuHp74nD!0Mh9tCZ9W^xw8m~Aft_L2 z3KViG^hYnCK;qnIb`zs_^7R@xl_9x~tb=qEipW!?E{7 zO<`hreW{YOj3g4sTJTi+QP;Yin}wS<`$N_YO+x0HJ89Z%6TLm>E{U3sw9;q$U7UbI 
zPYM4u{Va=-K~jn3$c+moA>G!tH3JnRwbG51ohu*H(>jU-trxc@v=D;ze#C?wE(t)97dlw2ucob1A+zF;yuuWDx-nCTh|#7ek-n72`- zwDCEMQ$el$fVH}G{YBl=&9Cyj2`FE$M?-xJ{zjE z9Wsk7ZD&%7S;`nWHhpkmqoZ%WPh{BVcD>tt)v4iNe6queVSoHUUD<4#*`%Ct%HSq| zWG1Rwj?ng=23X#%Dk)$@C`e!FR09TNjYnd0)@eDPZi5Gd2w$D5)ro3MizV81;8kRQ zjbHlmGZ!vNraO%F&DotOUXHa_r?X~UD|45su%o$GToyMtB%TkvEThqnsOUMiemf;g zamah5RLsBM>G?UPPR46(3@K!Xs^!JOS9rgNs-R4IPk!_T6_39=$SFLjX-nj-iiY(q zk}YaJyXs$i-+@M-4fb!e!?>EF#D+KTT4q_Jr~ z4t^GXFOSa9&Cf+|SnZs`_JsqMU?K;Cy|~{)-VxseZW{uZ2}zVZ9u%Z*ot2W$CI3yK#u# zOnXUV7>1*DS)sL;WAwm)YL*ni>b9nISA@t~CwbFJ#;!3tTf%b=_~e4-kj1U1WhEd= zN**idr$=V@4UAvpva7Z!e)lubY z5@AHq-0$mchh`36&598O(I8RdY~@LztHn02<~n8ZNu3JBS%+QW1N#TWGB&p=C!Y`I zl-PHtc|lq4%16^e{PUBIUF8Fgxt;#8mQY=Bz@=h1Yvp)$A*c+TE=)HBK>3^|@#0tT zlY093v4-%1$Jb*J9Zc|cBhLcK9;AnYA;Px2!w00~CwrNZns520cZV)e;Myun`&3qk zJiAKa^vCxq=MvJ z)l7j{*dtNKtKNNv*}I;mH7$;}^|0C{iKGXkfk$gEK6L$kas3|ljK#p=;}HPZ-7;g} z2k!4)-FTQMz;eIW!09^Ey@M(f1E+1yX3vw9T2Yo08=6lPh5t&mb6qpWz0-)C%7ZV) z3Il9*1txk&Xg9c3Odou>*y6vJ2bC)r@ntma!P#d4`kJiVUUyHCb|FWH z@M!+vs;P(vVc0jjLt^j@>gxfw_;Sb0H8VHji4Xi($DTq2T45GN!0PY3tXnwp8hlvu zdP-t6Gv)Z2Y_px}`wXSOwsN|VF6 zB)ZTR3N*hQurZ>KL;0R=UMe3t!aZH~;uOiVyTFXYV{@TdC(tJIO8)D)@en+1?bn? 
zf_qgm#$P>A#QOFlJG7fy{IAoxH?`yu&8T?BWJEh3)Y*VL+i8w26%cj>%eT75MNY(k zbM(_IXD@A5TUfGcYzF1_VFwxe$S6@3oH1H9m7kMg+M-48juIqEuhRP{b|j#e2VfCn zm!zS}B&gnthFo9*usu`oZxy;W57YtWeM}8`EV%hAXnEDD5Cv11Ub&zX5F>e9t6-sK z;9Qw;)z-$0te{>V%l#x3N3Cdp+S^pshly@kf7?$>A8!U&{q0fya|V{vltP_IAw0u( zTpVW#o#z6XCu2?z=(l5lU z#wKPxBf*t&Eanj?SuTJ-_~|o&B>;9k$JeMG)Rvcxp3NV)O3Y{Bulj_~KdXFip}Yd+ z($thcyb|$c{KcVQdzs}A7l^gjvxrN>UoSBr&zP1w3$R>$<9=<#bIR8OT1!vHqy&80 zmgM;e0G_;PaU_HLa37Do0BTSd1BkGp5Fw~CeHrnc<-dV7NR0&Zu@1-o z8vurXE?NIUPW%Hv`3Hak^+Eii}}4s!OgsX|3-;h z6(9oA8T$b)-!1{dT`|QTMjBA%o`G=|2~!+=a5j*4CchdaC7WVUyQ4vOg1~PK zQ8Ie0{EO2E1y__xmLd2N7#IODQ5rP^_YLvPD@W>Q#NmBYtf+3wZvyK|vLjf+1V zpsjKy=u0tmMG48{{0B1&lNNEyZyQZJr4m7UbX@hUPMl?sg^A}waRC|A>j0DnySfWq zHMXIBm$TZwp=CNAyWK(n2hJTfb!_~H$8L{Mr&SRMGkgm6%epJ2dhwlhZolICL)Lup zZ8_f4oH?lkAuU%0H;33grwqrdTLvY*d$oiv5`g@E0NV~AKK5lw^Wty%)MRuhN~;IL`G9Ny<)jo2fE-_Q zZS@`YM0=4*hxP}>WQ|x$FeR6R!EMKkOxut zp5FaHrxNBCbSN~RtTR?2Am&M^+Gb?Xs^!0Js<_8}ke)K};qJIC&} z7Vov*IGHKle3^9eJNCyDN%b!A0mJ6i1Q@N)?Rau8}SRjSKQ6RZM z*SUeQy@TRUaYm0PCkwErzqa>OcZ_nh5q})5Fjr>fSZDQ~=l22V;ODH${xA2b zKZsDa^cduuEw}bL^0Lr|ooNMg|2#Cn{VPhzJ6}=GO*0QlYF+@t^<(?u?LauQQy~vw zA_#uE_QR3nmGtYQ3oFpqBfWaB?XCEz*9`3kUQBU4g@355pbRx69v@6XUpl#0KKs_u zW^wA6!<;kb1|Uzf99nSU*MflGuC%_&yP+lXnSo$Wawj7?{c(?Z;6@>FEo7kke zx}hwWiYD(IuuB1V)qix+<-IhR*Qz`<#B|^$(793*V0@Cozr{{Vs}E2#qm}Zmz_(ul zt~?YI_o781u!eW;^&lpq(*P27;7s0$`z;zQ-=nP24w-&g!QUng$C$n~@7}?$akK6s z9n#~#*6((mpS@5cWy&GV6O!h~dJy>Xz^}OvJ@m1G0^<=;rf${4C38Q#S5{q7|CbZD zEXun=gGD}?`X+mQUboBAlEB)FzFwxh3t=Jz8J4vlFCuJP%}M9l4cs`_GM|?6SO(zM z6Mmr0X9Fzne1I%F4%xM7ZQ@d9n0>i5#f7xlh*cL$&{r-5RDB2l(uM*_`u2kk+g$4~ z7W~ZyD+%nL{@dBRFR5AQW=D4dD*=a+{FQ^R|2W{i$zjm!W8jqI{UAde7Ii=6pbUV= z@7z^F!R~YZ^5l0a`BUZ?H#M#+?}mszA_SWi_XBt28iWyYU7f1KhP@ zh$?tug7S6+WvlCEfk?(|0CO%^(=^OOm7txyrnxBZtq z0RcAla-?{%QEC;`o>Ca=`<3NBQn+0Kre< z(SVE4(ThQH*17LTi!ESa12*Q6%tHYBJxg+LWU}~d+ARgqZ?mtRz!tP2J(+o(cueZ{ zSX5I>(&+%oyC+z6eFEgCBDVdM_y{onr4flsAV6{T_x-6euo0&sxQ*!KZ3}4Ub(%Xt 
z1t=MSeJJml0E;{U?EiD4*H+&=vr_;l(mn}Takwb;9l+}We9NFcuP0)P@<@mrA9OYn zf7s>}r=`BpWd+n``0B5zz}+b0t49vnT1K@b9onWfoU+}6zqVT`Iqr!n=r?r`VB?1L z%06_+$=wa6x9JbS&X?knLu?0Z#vbwxs7*p*epBf8#89L{BKZft<31uSXOfP;7GBm%ydz>po|7ZgA<^5D}N}l zT{@LwE*}Hq)Z!W82=3FPHTeN-5X2eKA)xiAurwbaNj?cl@)gKK7|muJ8Lu%>)g?1$?V17%&p#X|`{#S|?F0nbIkiY06PF{>H1c@n|Fhu7Wh`lx|CT+34dMv1i) zqIKXv0xgvDq5pvFDs*3C{D(ohPE*{)S$HXMw!G9+Liw@>WKWVm{}PLzh8+Z-Ei;vr z^!q37zoViIZc`g(p0Z62B+5g}c<}a+49b1sX-?pgx_DR06u)6x!-?2Rw{%zT_Q=@HF8V0rHB&Sl! zEj}SL^|-r75jjio4~3!W@unawG&3mnaC6lWqMU*UbPA9DJOxUg2@(wLWGLCe8~YktAOET${F#Ur#ReWxJVkN+ zRC_3a7nEZ|4|74OMJ2y5feGg;fD%*_kTIvIur8#+4<4CTO6oQgic!Sp7Pt?i$-9{q z7PaGD@vZIP1vsPejy8VB@ledx`<#aPuyo~(Xy z!Ii2yP+E$hLdZB&M|efqffVIi3zfSZ!RK59rsyFT=0yo=4*tPPa^2z>=8&;j_h;1% zKS3>qJB%S`Yikm+(sn`t>3v8hf>Qu4?Dp>O#F*@X)fyPyZ`LTp1p59MIPz`;_L`Ue59=AjyF z=`` z+1l!?i%?gltJxrhaHi-VGG#G0k*8=kD4~2}+^!<}@0wBYy0(Gy)#-G~>#G3=u*7Qs zzkb-S`Kqj!HX%N9vQRU7&J2o18Fqp4kqA=eJf*_Hs|_s7jSH_Xl+eE6msA8s;zJU! z@1CoVKu0tr5&HqkBAz4vu)99?)z^=J*G-yr;(qOW-r*n<0Di&&n)dF05ZEa>p3ml* z3|vt=m05VRvr^JQXgc(wm&o3M(_gY`f~^ah!9zQdH=Vw(A#aOq6Hdj+SY-9i**^AW z;-b-kicVl3YypAR65QH%DeD7g@a~v>3K(8}Sjl;W8>Rq;$HbUBx9dyW!#QC1>A@OT zFw6|i7TK>Iwmo}_;XB_${_cNI-UpC)#nb$U{e>}jrUvhFZ5@=H2++J)w+8A%kzH4$ zwkuakU(SOly8Vt3;GSEb`wU3)ta-n%$=U86^`dx}9ndfc;lS`n{0soKj`4tD1?#Vx zzYRa#2ZnKVGob0m8dR3#m>a&F+#U|)1c1u;!~cQ2|MS@2Nya~swJUAVk9+tzL7Iai zBuc3rzzRx4BIp28+2L#I%=RNi+Y&})3!M>X{$YA#spA`OD5!%3iz0oVzy(d5#K5CL zA@j+(XB*pz=?~_ugNua;x0i3#Pja(dy}zd|=eIi&rPMt@(GI*nD53yXUSOrO`(x@2 z5UL>?_@wR*fDf1H%2KT3k2R@`_JO;7brDT%QFUNCPUZZNbP7RKyr6qbly&Uo(FE{a zO}>twltLtbI1JX**kd9|J_FR^$nQf+sQ}qwrBDjKh}l-OU!(0HpiX$A7JdnEKoes> zO$Ws2tb4-&$qes)b4k!v>z<*Hg4$_~OkHBju>mt)tDo&64I1wL{KWPfb&E^)p%3)B zPw(HkYd6)|%5Tx7HZ9Sm<4;YqK!5HVH`ombNd@&!=5_!`l$;k#_jyB_@&i2vcjuz| zJr^S5;5PQ6Tt7}l1!B?72_8WZz1xg)$l7V&RI05T_8m!<7<4Z28ha?|6BgC@;HVDq zJ7{l=dwSOVStt1RIaBu=bvyTFq(>^qBnM3AK%GQwilH%j+EQZ1_299&VmC|TK>x;S zaDnv(RJw)urU3XxY^v##XGWF=ta(3I%B5<(JXX|jP~$EaVy(0`uD%64GB&u5Kz~Vh 
z_MFMDcvLfLUHb*9rEbV=f8#O0X93KJtra-XBVsY2gaK-d;}H)Lk@Yc(f%SA2k86p; z_$-7=>*djf$dJvt<{Pg#sp{4G-cD0$#yfZ22G9vPV5#64h7I{8Aa}cbXCG(_>_Ree zHEx{%ujcv6;4+?IYKilS8@IW}IZvVlwQE6#Q5Mf5sT1a2-O3yz)O5@lhl0NBc0iL> z6U*aToL{IcylImiS?~0kecimukuLfYS8$_=)lQSx&TQ%&rLdIHr5ULC((F`Ep*`VM zJR^5F5SH zsExkD`J}Pw+(ImEn8ZlXQVvR|ig>xV8i*L1DQn4MrznsTNPqwDS^_RM5V;Di@H9Vb zt%ZFQhynBIw7>@d5ovEt9>}JfugGeE0aB80$3{Vv;;s`Sk^5T4hj? zJ3GqPUp>5f+R9{xigMsT-hztI`QVFZr==dhq$6<#dWQi<9!r`29ffZII~}P|hIw z;vA0>`~=ko>S~r`{8WFIWX2Q%<+kppF8Kk#+7cCrpS!dJlXyj0_&F>2jVioCVkd`D zuQJ!x&AzI#ZZLwk_lOZS1QcE?(Yd(gBC1#=(`h6nKEcCtioOCtOkR(oB%l7E9y;G_^+;i3subMMW@1xc z5Eq{HqqkPX2HMN^rE|>!76pzx#&#G766g}qGbx{3(c*<4(|zxS?_U6*FTOdT!(fr} zW;Is;11+jqEi2Tr1kV73j#s-y!R*-d!}^k0|)LF%f8d!tcNOoB%3 zpVrQ0YgN9=Al^uD{V=)^Gg?+PKE_8f3XHkd1i+GsXKTK{?X;FD34T}%o@A8dd&mEX z+-h7a8vKoiD`E9hd(jIW!PC*&6on6^1)-1T_b<}}njplvdj6&oJ9R4)gYw8;YIbRt z)*UZi>MQKgcdaoADjZ?Iz9C<}R2+xOE78j!&QJ^cogeV((UL9PN)|HXR}`}cJ=sge z!j0-N@~fYCNRaAml1FuuA!MXEK4YZ6Lz47c2;(NH*$t}_dS$^)#*rVjSw4hiw#ez}sH_8w3A!K4Q*U_&H zODW{dWnjzX=!*5i`m?DUd2xufTtwbC!I|{m)-FNqoiUG`UW7i?+ zZ;v~%#~b^B-Irl6U=}X~Qs>pHtLnYV%-@Cflhs<7wCcI?;{|CYBCw`&>`O+3D=^IT z*M5I+f6NSOAmxOW>c*T{L3vv`TekVIh94I62q+vT)Sd6ktOsqTCE>sjbkN4szh@s= zYw~(adsl z$dRMY5y9%MlD+!}c4t|A=KShXb#& z$L1~z7celtMaSDyrNEL+*BFy7Ky(efu#sNtCMBY5C{gF}0(mXA#@k z|3!qIJ{!FM0ohwHe6Ump3=^BoBEc{h82&9%qYN8>VU4{O6=0YS48LaCb9y_o0SDz5 z(-iX`$eV&r^47oKAJBSf5P3+B&#`;T0~;I!B{p_vlD4mXp*s^M9@`7KR6{Wpl&~~Z zFg$DSc4i2dwuav+!6X3+?#Vg zxaNL%d$v;KO;ra-(EeR!@elj|k7xg3DfX>^e%~xXM1U4}={}9xg0vOw(Ifd23m_|d zN=M%LnwP^0WC-8~clN{exAWd_v~n&bS{TVLIhn%OMoU10+Y`Un?~24)IaICXrW^t# z@zT15xVZuT%C1qdbd6lfEu$w$UK3AH4>*lXjg==ERD3aUyF(yv=HSBIWdX}PD5s(6bn+B#ot4!V_ zF~+?IpkiN<-ahb==NYhT(Hq5j+7_p^0F8gb`j}#05y7>QD+I3#jnzw?cPfHi%S__c zeg;{d|APB4BmiW(r3MNi-WuE9B$DM=-pE&b?8yiz1DVl}B zjjQ_2d^1NC$+Ov!bEU(04*c_tp@i&^o6+j%s;#P8mxF=e#o-*g zV0rnmtEO7AceCY52|XKEdSyq0_h!*;jYMLISW-DFF*7E3bAsA(uJ@|b*S8MAiQr+5 zBsB_$gy=!x3>ALn*FDegwK%A}zD^hH`$6{>Lzmyjr^#Q^H~oG2x`_$vdw7(;cXOQc 
zx_f-we}%JzM)aVxdx`@HM~o3!`@MWZ#<2RK>bjHid(qZu(ckG z3tQ^a;vzO5krBL9Wn4U4LarCvz_?{<;Lg2&X%>n>S!Ibvv;c@J1Ok$`SWTE_cU2ApCM0ub*P>r zO9YR%kwsk$+vf`ha+cm}p=-SqReDe!S@AZc&ayQG!+eZ*TeF{uJiqsIAiUNTZf!UnIY;zy8Xl;eD zTwLPFn&ZX<^`tlFZGM6vYw1Sj@)Y7-cF^)9C=JBFCaZr&tn?k#9308$d!h#=qOdFw zodoRt`Ofq9rQGlPuDAw0(*vunSuDeS@SW@bZ2Ik^Gd!nRS4XtXvK`laBJ#ilZ=%(~5xmq#CwGyO>)j^q^YTX3%@Vr0neflE}ytMsD z4oSrWytQFlX*H3dk4`!mFqoK)a>(v`q66)$C>kn@bHIL8w81q26Nd$*@@xW0J3i5@ zmHQgVvgfI2NC+$ug}C;Y8x4lkIx-(DGfkrsK|l7e%|}}p(V~qjtQdS<;&#cuQ~C1Z zh;d&0KzN3wWkgrV_o8t9g_j*BPC`kqUuAqM%<#LQO-Kw5Q0$qe%~=8_kHs1Uaj0;f zUrP7;#TY>6xwCr3z=g4h?3{iK^N8TcVg0QdUp(@B?wP6?=0Fd#?_@*_mM~wTC10+0 z2ZnOXyl)|*SAeV>pGeG_Ti1PV=^Ns?p_fFMFqLx>;KSMV5|=(|t3bDrb@3F8F841E z0FZ*Q7aN7`v7qFdf6OPD0tjUQ2?0-P#~2TI9_7ml!hmuk!T*7{^v1>;K{z^6bo{7n=Cn`!CL~qNxq$17*{^HJU2_pA9jRthMtn0f9`5_ zgE{UQIG@lDz&QSnOx{gZP)jOa?#tb(#p_tUS>VJaL}KrQ26!k`0FXZ~2=ceSef$46 zNZgiTZFGEZrsc@Oq#@6Ee_9{c)2H%FYo*fybf2J$EnthnqEw)FJs(?cOTkZ^I-D7l zdQ|e5MXqn>H*iqo?v<7OWwQQ;OncIJ6|((b?Y(telxx>7u81OvQVK((L#G2sr@+uC zAxH=c2-1Rx($WJAF{C2p(2aDXIJAhA2n-+~Qj*g7y9T%W@!8Mz{k-RQ&il_f+rRYV zea}@ZzH6=Ty6#u{7^fehR0&QB6Tu|`KP!Fs&_57}2!6G0=q+xO>L0Pin^xgYpe z<5&xv9DvR8lihhC7rt0M?d%p}cIyYpbAbYsn@3kbwe}Al;J-iYx!LlfE{Nr=@=?P7 zeg%k`{{G=J*u&QnqjFJ}=~TFVD+TukYA${H*+%exf?Tik@4NUTMgq5^{;#tAEt?cm z?%hnshu&<$Kl5tF3L~6Hjf;-lP zGRp5THEYC}gFv`*M@;lM-U5KFrEYucNz%80s8pvNy3=z0Y1e;_;e*nzqCjjd_ z0&;tg0OtC&a}5-HkHo+r7o3fb9d0HdE~N$WzBA5V;%K>l)P?_f?4J$!M?!uCD*w~E zz?ANE=&_mZF}51-)Eg?4=5-KPm|pf#|Gqg1=X~$~)SKy+ z;g*q;;rjKt*TnN|&-sPve37(dt&oIu-=JcrHX1(nnUv2G`|-|A>#OVb^fY@}1&K}v zoA;U%iIvAZPxx#~ZkyE+iKI1rD_J^loLaKDP|&-d;Q0}Md%aJE@AmhZsPE}2arye+ zVCYRc7qgG=Qj*VicA|&u>eMPePkV{)3yIvtp!{jiGdOB%s1tFYmS%i@%mlMABaMW7 z35CE;_iD0X5X?n>kCczWCN2mi${fQZDdK^%B)L_W+74?oah7U-h-BiTE~V|B_=1UQ zIdxgSqUuiL^Wy#UV(m8vM)f~2txromXY)!HQ{YK-E<=i%Bh0Gcw&g>8<=KkJnPh`9 ztJ@KD9Mjf?2alFZ-Zy9$;AQ~^>wthcDarBB$!arX&W*9!hK*o3ngBb;azj)eGGvR^R4klkt& zWqt~A2!8RN@!0Xxwv<{rkddd@i2NdyBg{3MifiC-arJY_%4BRpXugCCp$-c@qwxQ8@YD;0G7 
z9^cz=8fMCa<-9xgsLj#aRC}BD1y%Cz=^h8K4OST6ZBt<$Uzt`m?lIC*M`l&JX1_bfzli$&{=n|r0hEXp2ru1njH0M|hW;hE zDoPFsVY1`}he&U5xdu$ap(Q>EcZhO?@1FH>Jv1LRBT~hpYB(xgjT3dqC6a`8k|@Yn z!|7Q98C;@iT`?c-7!z~zMYC;B`?$!r%;-3(X(PFwteo7QO(XQ566x;hr)TXpj1eX@ zHYK|3&3Gdap9E?b5)97s1=VjhR36_jh?JXG9pYg2@-S^*WUc2}3}-AQ-FEI&`?zcv z2W@t)@4Y$}yJkdKq3On8p|MOfW{P1y;UzdPrla$s`>)PRr|I;QOJIc~LIEHgo8j^#h>KT%b8F?d>8H;U1gKnknkCwU}H*2lUHX6m|4?KHEQ)a24mb>Y6 z_>3QJ;Z`^g%);MEtaYMo&D8s1^ZmJMm)E0~xg%3~vKg&Co5nF63^d6>F9<36*OFNV z^CG3EzjyESkA_628I}?0zcRP{8lQ1Den1d`s?`ary%fQVhajz3zxFov`2@7MGc+Vdal6ALR=2&AtBqabegO=*nnPS6*Epby{ ze0d|)lIZkEm}5DR&*T-Svyh5 zqjS|nPKqrza7>vV@`z{2V=rv8gs|uQA`?s_EQXyzut|~m)x11%W=H$Ghmu&pYR`FC zWSlaJ6@OI?*uQemzdJukO6bC{Q$$m)QRGO?CU*_7&q2p#V+^n2lI^bURns)=+xp() z<4P7W!K7NVx|ex-HWD9Eg$5h&nxv6#>aYv(kU*zSuZZ9?-thr;2ES2gIzd^$dUcaJ%c>;0THbpYiiq zoYhm-W`>uNZ12*P_~d1VR*u3>h9uo0~NqAMcTYSbe)-f%xlP?$#h+7gWR&_nHT}KjL6=S!Pf-s-IF#o28jt)r` zGghKv-qC9>$&~4CS&Xr*&nX}r({iGn6^)}&u(FRAFx6=(WYT?QpekYhT~hw%MMXGxZT=6 zb`GCxWO2pMC&X=G0;GGdua3*r-rHuV7Iy1Yx{<~>=|yO|xsBQLlNxqk17lBFA=yS< z-_2CZ4bHjd>^rOp_dfanJJ^2D&o^CG(-8(Oqq2#IHVgeKu$J857qcGUZ-> z9;s(1u3rn3wv%Hx^R^g`^)2&!rgM!+U6hGij8IBozTti5zO_f?^_cmx>*{CBy=M&j z2_-dNrtDhV&Gng-hZ4fjJ1QL<>PR-vaW{>XNcC4ieCNg@@0(ObWw^Ezx3e- z&J1@XDj~C&`K@?KnH6#9h2-f|Iy?kw)o{{W14HOnN$@_;YdoQ0tZdnb5X^w7k3VH|Cs70}Sn=>KLgq=mz zv5tvpu1s`ZG`fMB9l6RLD$3l}?U9@&X|)gy74tr$%gPa7R{SAtrv0r|8aT_SBN_$i z&SoYgi|r|jZeCdK9HmjQEasm)b9nlSY1$N$v~1uS_EljzeQeZ5(hWDkZ-aHo-%8%T zPT9&YzfZ?CmSxtg_=H042~&)i40ElJLVPNsw>`GfjpwTGS7642NBRU8=%E~irDW`O z&5C3V!k0)}1k_W~Z^x(2$1v&kjOuDHYF~y&3eV7Xgbpe{lt)HWLaXH?8QEE@t=D;x zWFo!E44ep*{>`VJmu@5h1rdivadVJCm3;7F0G)^k+hVREKcy(E4uEZ5F=1L9=e0AJ zjy$bcbIClNcvlaDh!jsH-nm;Oq?CP5x3{^kCkq%o!*fw)T24e}dhpaOma;-^LG9vb z7A0gFT}jN>V2BbgJQAPK1hd`{u}0v%a|wVEgN_^^n-P06xkA;Qqa~xlG0i1Dm^5F= zN(0>N7MK^45t7YuykT){z47LoHs0=rPpmszqMG?__L;4Bxb>yJl3*XE66EN?S)WlAp;Gsm3@L{p)y9X=UHdG zwQylxiA}=6wLBRM1d3WZQ{}V(BhJ&Ip??!juye!LPNZ-sXKUR;cPOvD{rb-Ty=EKcs+_%c 
z?zc)OU%jWk1fR=G!Jfw#$Av`(5DY?Q6yTIA-2T^)xA0VMWdfy1Z&DON-LTSoP3o4& z7`RDK#btkkw~PQI@!gXbO0Xxblc&cKCnZ>)_0dYO&0B)_?5I{|3wC)^#fM-;SqJr9Hp>GHgq&go}O0UtP54R>+B~zVhKGm=%+r>`PX;jV{U2abjO( ztDqp=CA@tewfo}}*K~(%2er!WM%5c9GHA6PImE(p=bkeU6k$TdL(&6pZ0oR~q*E=$tVhQRZbl428Sr6S>Hw)Z;AwNAkbXNnE_Ha{`Zr zp;gV6J|`U>o9a3F0T(DZ0`yjFUE$#?A-|?#TffssL);owcvv;lI!gVS*U9qQseV2$ z0}m+N2E`0;BMA!-H!=F(?3V?}(_{F-U=k1#;0Mowlc0>2S~~HY`QY*g$6$>R@J<5m zEak)gKB75_!Jk6f+8~hix}epsk?A6Bu>u~@4-cmkzr*UnPX)aWL&gDRAN{|5fyVG_o}EOWKX3j3CHL>6Ne^5ESA6E-M3` zG$R7zX|i4Kl|w$!hbUNHf)n}tE|qaKGZLr>dvvDdWwdj&|LNfgZ*`atbdewR4_Pa**ZRX;Oj8IK^nes+=#Y?aK1zgh})hh8dt5R-!^ zn}Nq;^%IQBDD zgnJSOHdE=PxEeWZ!tyR{x_}dg|mH{ksx@YlIRXA1%-=qA|{=H zntA_D&F0jRc_Ul(?_6#53pj`KtK#+yx(8J-89YLj*%q7^PeSoNOY1+hb4q5N>MCrU)rMNLr|bMruFk1->FWaF`}WluJ1PnN<~2M#GQR% zfdtCV*k4I#lGc1O#SMyI@A_`@%VgPnDCY8MoZC2U2L14=st)Vba&?dgT}BQDWPLVY z9t9_dDhC^StoQqe-i*Pg5_7icPGX*`(^40a7j^!2t53a7$g~-G=P=UM1-M8ZvR^#G zSsXUN`51gF)nhQtNxUz$yW+v5S3^?!R60UtkSLveeWj(*x~a4%5c3CQ4eYGX8`M&= zs!R+`)vn3xfU~>_63>5D53c?AF~(!w)BtMqRUQ*rk_(R0-MxbUK>hz0C?MC#0 zWE3Rm?tfnc{!}IC7n|Y_ssLb}xLiB{K-mt*9Y;N&K|r3&m1?3-!6LtV-wpvUm8B%* zU&km48kGk8`GO0GGTT(r_uiGDFp64QNAkD``ctmK`rL3b>3xELVl+f7CV&d+8D?Bq zd=2^JX4ogQA4xx47Wrr&q?y5q#@v{sR!c!)TER8bQys3p|Fn+$|?RpVv8iv(V zr|l_`-phYHKyWycYOwp+OUP!3i*7Xjx3|k{(lq)JE*X#&-~hu-XqVsk@fm)$}Uk zg-ab&648s_YV0sy%g6~AUsYsOtbV*?mu_$Ddl?Avsow~3$|wer#J&1HXPwV^?rgF7 zNQSoGA(F;dK2XnXTSb<>6(r$*EnY5!1;#D7ch<`qqEb=U-x_NT@|$47TJmlzUMQ#Tkb zODRn-cba|$ia6~$lIoOZ{5dtGD(#3E3Hj!sN0n;TLcg{X@u-D$)dSTfX#^Y9>PF6b zPyWNGZ$qi3j1rBze)m_PpyQBq4TBAD(K~QuIcwR5%-M|9NKuO+1w__AaVa#jK}1*I_o)DOm@XeO;u5nWPb2 zJqr>@2>Rcg1i+z&PbHk^n0nuTx=z7k;0sZDt*wq^B1BOW?*`0yZZ3LEA>TBVI(7kB zx%ImZBY`GUN-yv`zNR;5quiXOt{trIr{KHOl`q!?(OB}2>&fzv@5>fC8&|-dbD1A5 z+43^>T~?i3G~@V98j#ExdeDXCG+_ccQ%X?hgJ#P9DT|-|q$Pq#Pv%8Rz@u9$mocN> zYX~?0Cd7_-_4Zu*rW8u$u&AiW>$FK4a}$$Hqb4=0@k8sxE2bDm{oigQIn>oSf-AG< z{&RG}c7J{$1-^n4yIQIC4igQta^s6qkf$Sfq#eNsx7s_K(!NLvy@^L;{42Qm)&AYE 
zH<9ZgdURPee$qij3F&Ae7pL9<8DRkRhm?;#W~vRj?LzNA_-I-3z2^IH0cLUxeIA5Y zf5NlB{?vFMzjPm07|oKGmbrLsDZ&wFscKPft^J!^vD+b1(Uuk)G2pmvIdK(6(vYiv z;uqf0(kFjdS(CM_SUD{BgVHdm^$4_O@rBb?AKf!QEt+{$FO3{UKT2ti66Mh zT=na_=g`j1PG-`3woCE2f|0{-H|qja*$OG;xp{||WZ!juWc`iT#>_{_t+C}EN{LC? zx{5n13dMJO%?BUHy+@lTX6fESV%^^yo9)TY$0_irQJGQsNKWFnVviEJxJ5tHgUOi9 zbc1P@=j-pzRPDd2urhQg8G*u!SKe`m7MFL$s;%dfLi_o-lHiReqdPWnmx!ewr$nb z2+vo}$m+H96x_e3S7N87an);ULE?-fye&4w-c|4Gd*hpv!7mno%EVfXqH)fc=r`Jf z^Ch9(f^wS(Q~=w<2{mBzJ~I5~*`JOZfTsC^8ufz8i8C(CZwS)84i8qpzOX-a@tJ^S zq-v6Q)s1ugb^wvBF25!2`0!Hp-uhg>zu+I1w(^7Xr=q)#1GgVvyvdiX*3$#oKLGBN z1~dEVRVtT$sbaPB0(A{*(2mXPr^jWCcCMN>6H$pcnsrO0RSj4LFq))MyP~Xruf|t- zq~F(cr9woeqP5bpv)MoGIemWZjb?`53S)zj*=8!I4_gtwJn0|Frk^J_oHdVwP6TvL z5wvdKUUAC>#$L`hsNgwGanZCrnte4>YCl_kx6*xbr*AWbS!Hu^M81JChArm2EU3qQ zF)5J!%^ClMB!wnS9H*iypGmkmkEcPnNg9ip)bC`i&JHqPv<%iV~Mx$um}Yh z*~nwr{;s|QT32G?M-ysmY1IAKsbT1$*+GVI$(^h{&{xWaK;^k1i=tT+#?--do30_Z zjG=s7*QNIv8pu6s%iU%5j{x0M2Rwe3s~ri?g{Z<+UeAHu$+7svxH!M7D4ZtbdDJ4NCeCT43h-6vr^SS;-eSwkbP z^tM4EkX-qRqi{vzXbZWsIPX#p@^Ut)d`TXYG}nMo+DqXka-HZW;Cs{eWUu`gICR-rX@U7KiqLyXOOUf#WUo6MJ zF>d5QDP-xXG^l~R5H|~ew8_SNn?iu};O}1e-^doddjVOwU~1P3!cZ={e{xs!dN*Cw%hRTSv&M8<(s9j5il>?YqgbG#|NRRo^eV568c){6bY1`f_a+#65P?~HAk-kQk(FKyU1^(G{vr|HCAop9)MtT86a@UexdpC z3!DUc_06JMCIGYPze0@vjivpQg#PV`{geCuTbk(qpWHu$Z$j$cLiM9RiapquzTyw2 z`$yB#4-QAgzAvt^B)eY62YI{3UqgCa-ScPm19RFGWPKi}16M+UQ<=Xj6aJ`U_`!g; z5(Zo+hY%1z{s*D#Ulol1cCqR%5%9S0fmg>~2APw;rSM+8;+-M5Fr^@C7kynYKPo#s zf!$iDz*3P~c7vYVc*Ly~E3riVye=a2n1!lUtO&n-`}(b|`)`$>AMD*4EY7Y$&OW|b zu;l;c(v3o}Xgki3zptL9{?~lKm_PV{@C$hHvkLRNlx7hWa-1KzIt`}<625-~xvh#| z6fgd&5{|NC2e$ezELDLVhz79*U>9h+Ixld)klNKY=c!bpcik?J{AN&L1x-CKe*OHX z*@%=8tq_I?gXF}orAE~y`>yVK>XKWR2iIm#RhfOvE_PiAj{kDExL3K!kDVfBBRRmE zT5R(C^UI*jn)AXny=qAqhAt3u_`F?w80Pz>x{9)6uVOlb@MeWA3G|VHz8w*s5sqxh zP4J)R9LfdnD-KX!sK&4DJ$-8bSixECP)Ij%``nvXk%6*_9c~N_T0$Lskr+7mh8MJ}&rZ=55^%!hWA9^RK(m_l~^ox>Bjm zPTKi>>1}M};Of@=<|(%qHe=go(XsChVHQ44Hgi~$R80>x&OvmOB!%njHR|ou9rthj 
zHD)uCtijdlb)&`=nszQ<`CJ}N9k!iMd`!53xUroz&=}LzaqgklkO}!GO^;@YD=k4N zuDr$<XQu!wmd>ACfGdj^CoSUM#b%DVE2?kzm=W67 z=M?sq*{=0kDg$D%T;E~N!0ZY8uGRSwJp&Aazt0KZI~W)P4!CZ&KLQn8)MV`Uiwo(L z``kRO9uqE9Z(L~}ZhNZj_%H+ai0Fp$MFQAr;cr-j0e@i)%4qyISc6agWDV8;vZ)G= zHt0TyL2y)hej<`mjdQ(1OdNSX*;c;{6{joCd0WheMvWF1i$0hgbB`}I%XVvCFW)Mf|!TVRiKbit6=DcW{P0)BPMcjMJnxZ|UlSW@) z?^@i6w6XJV)394|$UPP0{q7r9<^G-#>aHnTm+vgwjYSg}@5nL=t#K~$9!}Bct-mcL zyX-(5qtv0^NZh)^yJC6B$FQgs6l0y7aN!+ZcU2C#JNlZ;<^Z-YyUB)z+=MCA4`GUO z;^i>vVHfNc9^VanZ>7k8LSdl6aJX{EZypkB#bJ&;pjF~m>{Al%etu)BStr)PFgP<_ z&uQngito9wlHiinhV8jb`CGUIWGyy7{*~}w2!_A#y!DZ{&RoBBT3bHFI*@Mab0S=+ zu;7GBvkcO`zV_1M@vG{0%=T3tV}0H9VPvdU+A=D(7pNuErwR-A?FaX@){e1XExFWL z9LBZExS$zt)uf0e3u77wyuI1_-u8?lVMYDs3j=`#<5F=_+-~-dZ*+f)74D97+Ox%G z6sE@-#_Zff349d@sbp42Z`KyMtKHB0n)Wsl-0Jp9@gklGju~Pg4$^9W5RrPUJ-Pi^ zx$refpUjk5%Bxt4mX9c(|F*(CTQY*mE>J+a^S~#)UpV-l`|{yrRwtZ|sn5Jju-r*REL}rL_#Ztj=jLu~^kd6dy%g&*rRUpliYuWQ=|5 zPTO0V&SM$JGV2->seE_wZcu80h+`qSkH_tRO*+nG8-3kdITu*(agailY@9|}S{xEm zt|9kAoFur?mDa=0g@VdZJTQWlO#qD_(dt=VGTvyxUo?DBUt41SvebG+W5@3B(QL!p zE0!ZU?CW$;IN*Rg6klGvB%p+!HnOLAv85ci!o*)ViR~2oB!KbFR1h#_E;mC3(_@(F zC<_nskBgOpS`n(YWbmv?g(pHYJuQo~Jf!*}dCMw7c2p1Z1eI?VD#!MKHV&Fe2rP1s zvx;8OSZ$y9{JBeDmN*Y2|Cd>MT1GlqFESH+yOM{7x{*SgSpEGDtGx!!>Yik6kYkvO z1SG_gcfkAVf3ir6PZx-J2MJmvL`D8MB74mr$0#4)r|i>6HJ)29qsWXi+n%q@ z39*Key|+IQyB?(fo~&Z~p3}Q$6wvgIxU4D7crG_28+<&SUyEKs>pjF>ue&GKajZ_0 zcCBn&u*4|o5+Dp}Md*!J*A`DyeO^l|nmN-YtXH{ua<)t@xe-pQJ>WIG0c4KN| zy7e-=&-ZKAiLR&<&Z22ek+GWm#O?bxrDnfZhF(>%4Wrs5p%GF^k4zU(V*SXJsG~WT zR8oAVE9G=@QR?DmNDUDoMppYjsErLKJ>Qhj{(kzGOP=(r6nA`dJE>yPr}=!@V##x_ z+-bg2cD$shQX^Aqf9KW8o2%(YF$k3Aq_CZRQW1%&s2=9R4^*)bCH_mwblKKgNIzz+XF)>P;}+_mgU*YHpC_#(BNt#{acU91!HTxe9 zQ;e`9!7-t~Nw4MAM5_7t%&WW|w~6)LFN-b>oHk5i>nQ%SDy}_TyjIc1!w3{STVRvW zdkANhrhd_PK+`&Hn)32=n#)$#8KBGzQ|cLQ@hX)bRiD!3jIQqMJhaVD*Xzf0PL`E7 zwPwB$Y@-_6E$8x@w%vY9N30llzS9MB;fvw*&B)=)5l`%JF-q+DpHUBvpnxE4mX;;G z!VT~FMy`40!INZ?y2prYpZjggJ*!@={nrg|To9}kQjCuZQk~M~35!(ST@qF-c;V)7 zMK%Xzj6kjRDSep6| 
zTUz#J-;TS}h8*&huLBOCx~N5sxxgVOKY$@3PjmUA!(h;^Um` zUThRV+hZKI{qerYJOR1yrglw3S0Xe{y+$8S$Dj5gFsx9r;oxx`fo+%Q^j_nmx#?m1 z7d#c%C!vO39V2`;GmI|QgOOzs>FWp_V$*;@fKwG*q_e5MmsnUuG4 z$Fv3(yk5(Ae-K&u^6KPE^n~G+an(UjfK_s=?YV6(;b?x92{UApEm_j6h7}wUALF=FNho7<_vOn1y|p1uzWD>UDcy#8!DmWjNp0vt z!sH@yp|#fCp>;yTm2~&G^S)y z8J9}$CsQPYPMHfv?}&3T?|iFi(14+L6W!ZNDLNjZBW1b02s;9^vJhv2+2hn>p#ko) z2ku>T@+Bs?j?_DCYS=ozxNf<)@0h8ax%4k#VcBsj+6D>*flrr?0(lS`{6rwfyybv8 zo+1FJ;XOpk#tLu5yPR;ZN7ifY&6ouYy*R#=N*tA2P^0F$G@=kdOvfd3Xqh3t`au=+ z7#sVfUvVvrOdI5lS9%f@ZnRV4ywJkEVcfkiTyD!TG?wNJeVTOXGtO0?vlu!1GxGEE zIR<8dJmQ_ulYt>4HyX17ah<1E{;uHr!!@(j4U#!!7@>Cz} zEwd7ji!d$N)Q>=A^InFY17xf}mxeQAzno}2-S}GWAm|PV4;AhvtDp>OOBa~)T=Qgu zhS4xD6&it$Z$DZZ6P~!UJ$O&I_2hJ-#wUBGP>2R>UGq7OITkvdxIco9au%n>OT$0T z3YMULg;Erd^p0{bQ8#dO)uS|z^lMqXbWmbRo-Mst~(4f4f5xe3$MhU72<%nfe z!?=g2A%c_{rcN)RdR}UkO=J!6@=ZOns0Fj~LDe0jyUCH2ZW;rJPCfV6d_w8W0 zatpd>^%lb9Z{&5&Z|}+|lUxB=ClgxBwx|E@tFD~F)kU5trbI(DX05tkYDu}bYWSO4 z_PgywL-t*qmXK+B6yudv*TeL@JdAim55F}QU%c4m59#JN)hW1TR5~5_p4Yv>a#4P2 z#G=9Tqy||F+fG&BTc2pgI|<+*g*7n<&NKw6ZrUsXUZ~pBPF(ha8?c6=#MWd zSP{3oRGD!~0yuHT4~jT{mise-apcDT7bhm^+1>Re`hD2*e?;j~oWJjh1U z);ESTIxr8f^KK97O(*bL?B79~RS_f(2eF37jJJeFes2js7fYQzC@`Sj!1>Jb`)r2V zY73Q6d!Gs`f$C7XZKT$9F@t)~wO4#|&o24zfvz#?#c6Q2%a_RH1m51lL)+IXO}m5l z8c=sz5n)j1E#j_Oy-MQ`089^_wiFY>3| z-oDP#zTQwHzGb*{&+*XX@k!VkJ&S{hoDwOsOrwL!RNiio!9l8I&2~Vp?;haPl&Ef| zJfzu+fttPlrYAGHyZHEN?Z!}tKeOVv8|_NI#z_`eM4DdYg$0X+p|azJ&I?0rbY8V0 zr_-Cl!d|4)-%YAMpW~dXQ)f_W$N^SD?9~%^bod+~g2=&tw&I9Ugx^Vz2j;C2mxDRx zoj}f3v7Hm2xICGB__$jYZ~cS&F!w;m6fI|=)1}_>j_nVu9!GZZ z3cxKU$4ml95GbOe_N$;}8;Y66DH(O3G{~=9Jjo1cJ{({Y*>;<3zow*rhMW~P;p1PJ zEIdmIE=Xbwin=^EButpv^kyYh{Y;v|6LbFqF1u|t;kY!n4@>r+Dql9Sl+C$%gm(|p z&hdNX8=Z+ZXIC>c;O3%#ewDvpNv7LPiBr6%oKuYPKF*&U!+Ug;@4@_=1CV}C7%dj_ zmB1oFyA;dpUSnk(;=(J=8ZIf$|HQ4gIrNQVc#(VSYD7N6mwuwiO`QS~=1Uo48*!f` zKyc_A$sG0mupWD5l3rx#+fr~69a>H765Ecms5kC|+FVsJsBvJYJfn0&;-=ni4ssgNkA!O~S!yLm|8PjTUZMcA*!S z!3{oxWOFhet(W$F>qHuN5q`sg%-*8xaPB&d%;qd|i#3b6p?zXuwk^HMI%FY|fW)IixGnnk+bB8i)@k<`dd7E*am 
zTx}7C4jT^@L$Ro*W$g=lF27ZBY@Zo_#fk-2{tWm)-OKKlCB(~qUZJ6l4puDDx=IN7Ulb0v6iM^ zu;0BjOKn)H50!_UMHqI{h>s-t6bQYDq~DR?!6J|ARAnGgj(%H5T{@E_ilPO&q?L#D zH9n!bsV(2E^j@mxQLUDxrr)k{Uj1-l^Q3J*r|YW9lMLVb4AV??|1H;8`6?<_2iEo( z!-<`K;ys7tdv$@1CQPp3(E6*&N7nJ(jlF8;5&e&O+ib!d!kyUyt|#;#aa28?4e4+Jw6%hPNJgvZDhFC6*kNL;WOcJ4o18Qnn5{GwYs zViMl*Eb5@7>FwU(C`{H<(DIs0Ko-HmVf74|$oHC5vt*CWaEYp^aDKb!>^OBEL-NpM zCI;=W%rGia@z=Q$Y5{Z)uX(#Z36OV6)sYB+sG<4Gcsx(KAWOH0C0lQCYM~;weaUNAif?~sjKWbHT-P!) zVSKS)_`hSA!O6U4^cKKtzVjHD{lH;JZFxXryjcH^5(?*os?-C!Nb)CqJq7IhAJPS` zA_J_ES7DV|C;n}NfO)tABI-Y1dMyf&?jvBx3=pKy@>Vj+AMob}?w+hidzI8j|HK|; zdOMMSa4e2Y$&c9_fs_At;O^$$Dr9>(HT)d>NQI9S=TF#5?fu@sIa)U1Wv{?rZ&(9^ps)M%;7<`__ -3. |driver link|_ supporting CUDA 12.0 or later. -4. `cuDNN 8.1 `__ or later. -5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 `__ or later. +2. `CUDA 12.1+ (12.8+ for Blackwell support) `__ +3. |driver link|_ supporting CUDA 12.1 or later. +4. `cuDNN 9.3 `__ or later. If the CUDA Toolkit headers are not available at runtime in a standard installation path, e.g. within `CUDA_HOME`, set @@ -76,7 +75,7 @@ Execute the following command to install the latest development build of Transfo This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`). To only build the framework-agnostic C++ API, set `NVTE_FRAMEWORK=none`. -In order to install a specific PR, execute after changing NNN to the PR number: +In order to install a specific PR, execute (after changing NNN to the PR number): .. 
code-block:: bash diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index f68edf155c..0bce83d98f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -164,13 +164,24 @@ def __repr__(self) -> str: @dataclass() class MXFP8BlockScaling(Recipe): """ - Use the current scaling factor strategy. + Use the MXFP8 scaling factor strategy. + + In this strategy, tensors are scaled in blockwise fashion. Each group + of 32 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E8M0 (8 bits of exponent, + 0 bits of mantissa), equivalent to scaling by a power of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the MXFP8 tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. Parameters ---------- - margin : int, default = 0 - Margin for the scaling factor computation. - fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. 
""" From f9bd83ca95ed73c18e6542e058ef11af7c449c30 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:46:13 -0800 Subject: [PATCH 073/239] WIP: non-paged, thd, no CG Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 8 +- transformer_engine/pytorch/attention.py | 84 +++++--- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cu | 180 ++++++++++++++++-- transformer_engine/pytorch/graph.py | 4 + .../pytorch/kv_cache_manager_non_paged.py | 13 +- 6 files changed, 243 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index c646af9722..dbae01ad81 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -207,10 +207,10 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("dtype", [torch.float16])#param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", ['thd'])#qkv_formats) +@pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats) @pytest.mark.parametrize("is_paged", [False])#, True]) @pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("is_cuda_graph", [False])#, True]) +@pytest.mark.parametrize("is_cuda_graph", [True])#False])#, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() logger = logging.getLogger("test_paged_attn") @@ -353,6 +353,7 @@ def generate_data( ) for _ in range(3) ] + print(aa[0].shape, aa[0][8,0,:4]) #aa.extend([model_config.sequence_length, model_config.sequence_length]) return aa @@ -507,6 +508,9 @@ def gen_cu( cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) cu_seqlens_kv = torch.zeros(batch_size + 1, 
dtype=torch.int32, device="cuda") cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) + print('qkv_format' ,qkv_format, cu_seqlens_q, cu_seqlens_kv) + print("q[1, 8:10, :2, :2]", q[1, 8:10, :2, :2]) + print("inc_q[18:20, :2, :2]", incremental_q[18:20, :2, :2]) step_dict = OrderedDict( zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 85bc4fa6ac..7b44517773 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -29,6 +29,7 @@ fused_attn_fwd, fused_attn_bwd, QKVLayout, + QKVFormat, AttnBiasType, AttnMaskType, FusedAttnBackend, @@ -1156,8 +1157,9 @@ def __init__( self.head_dim_q = head_dim_q self.max_ctx_len = max_ctx_len + self.input_qkv_format = "bshd" # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache - self.inference_qkv_format = "bshd" + self.cache_qkv_format = "bshd" # layer numbers that we have kv cache for #self.layer_numbers = [] @@ -1226,6 +1228,8 @@ def allocate_memory(self, layer_number: int, qkv_format: str): dtype=self.dtype, device=torch.cuda.current_device(), ) + self.q_dummy = torch.Tensor().to(dtype=self.dtype, device="cuda") + self.batch_indices = torch.Tensor().to(dtype=torch.int32, device="cuda") self.cu_seqlens_q = torch.zeros( self.max_batch_size + 1, dtype=torch.int32, @@ -1459,18 +1463,46 @@ def update_cache( #actual_batch_size = len(self.step_dict) #seqlens_q = list(self.step_dict.values()) #cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + print('qkv_foramt', qkv_format) + #print('qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) + #print('qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) seqlens_q = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] batch_size = len(seqlens_q) if qkv_format == "bshd": q = q.contiguous() if qkv_format == "sbhd": q = q.transpose(0, 1).contiguous() + max_seqlen_q = q.shape[1] if qkv_format 
== "thd": + #print('---qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) + #print('---qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) q_buffer = self.q_buffer[layer_number] - #for i in range(actual_batch_size): - for i in range(batch_size): - q_buffer[i, : seqlens_q[i], :, :] = q[self.cu_seqlens_q[i] : self.cu_seqlens_q[i + 1], :, :] - q = q_buffer + ##for i in range(actual_batch_size): + #for i in range(batch_size): + # q_buffer[i, : seqlens_q[i], :, :] = q[self.cu_seqlens_q[i] : self.cu_seqlens_q[i + 1], :, :] + #q = q_buffer + max_seqlen_q = self.max_ctx_len + #max_seqlen_kv = self.max_seqlen_kv + step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] + max_ctx_len=q.shape[1] #64 + max_seq_len=q_buffer.shape[1] #64 #128 + max_ctx_tokens=q.shape[0] + max_tokens=q_buffer.shape[0]*q_buffer.shape[1] + #print('---++qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) + #print('---++qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) + #print(q_buffer.shape) + #print(step_lens, seq_lens, QKVFormat[qkv_format]) + #print(self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, + # max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + # TODO: batch_indices + tex.copy_to_kv_cache_non_paged( + q, self.q_dummy, q_buffer, self.q_dummy, + self.batch_indices, step_lens, seq_lens, + QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, + max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + #q = q_buffer + #print('qqqqqqqq', q_buffer.shape, q_buffer.dtype, q_buffer[:2, 8:10, 0, :4]) #self.page_table = page_table #self.seq_ids = list(self.cache_manager.sequences.keys()) @@ -1506,7 +1538,7 @@ def update_cache( # k_cache and v_cache are in InferenceParams.qkv_format format # return k_cache, v_cache, page_table - return q, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv + return q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, 
max_seqlen_q, self.max_seqlen_kv @torch.no_grad() @@ -7806,13 +7838,21 @@ def forward( assert all( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" + if qkv_format == "sbhd": + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + batch_size = query_layer.shape[1] + else: + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv + batch_size = query_layer.shape[0] page_table = None if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" # remember original format for output purposes - orig_qkv_format = qkv_format + inference_params.input_qkv_format = qkv_format # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference @@ -7844,15 +7884,17 @@ def forward( # update KV cache and return the full key/value tensors # full key/value tensors are in inference_params.qkv_format format - query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv = inference_params.update_cache( + print('query_layer',query_layer.shape, query_layer.dtype) + #print('query_layer', query_layer[8,0,:4]) + query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = inference_params.update_cache( self.layer_number, query_layer, key_layer, value_layer, qkv_format, ) - print('cu_seqlens_q',cu_seqlens_q) - print('cu_seqlens_kv',cu_seqlens_kv) + #print('cu_seqlens_q',cu_seqlens_q) + #print('cu_seqlens_kv',cu_seqlens_kv) # update cu_seqlens tensors #if inference_params.is_cuda_graph: @@ -7863,7 +7905,7 @@ def forward( # query tensor is now in inference_params.qkv_format #qkv_format = target_qkv_format 
- qkv_format = inference_params.inference_qkv_format + qkv_format = inference_params.cache_qkv_format cp_size = 1 if isinstance(self.cp_group, dist_group_type): @@ -7874,14 +7916,6 @@ def forward( context_parallel = cp_size > 1 if qkv_format in ["sbhd", "bshd"]: - if qkv_format == "sbhd": - max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[1] - else: - max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[0] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size if cu_seqlens_q is None or cu_seqlens_kv is None: @@ -8149,8 +8183,8 @@ def forward( quantizers=self.quantizers, ) print('ooooooooooo ',output.shape) - print(output[1,9,:4]) - print(output[1,10,:4]) + #print(output[1,9,:4]) + #print(output[1,10,:4]) from .cpu_offload import CPUOffloadEnabled @@ -8198,12 +8232,12 @@ def forward( batch_size = len(inference_params.step_dict) step_lens = list(inference_params.step_dict.values()) max_seqlen_q = max(list(inference_params.step_dict.values())) - print('xxxxxxxxx ', batch_size, step_lens, max_seqlen_q, inference_params.step_dict, orig_qkv_format) - if orig_qkv_format == "bshd": + print('xxxxxxxxx ', batch_size, step_lens, max_seqlen_q, inference_params.step_dict, inference_params.input_qkv_format) + if inference_params.input_qkv_format == "bshd": output = output[:batch_size, :max_seqlen_q].contiguous() - if orig_qkv_format == "sbhd": + if inference_params.input_qkv_format == "sbhd": output = output[:batch_size, :max_seqlen_q].transpose(0, 1).contiguous() - if orig_qkv_format == "thd": + if inference_params.input_qkv_format == "thd": packed_output = torch.Tensor().to(dtype=output.dtype, device=output.device) for i in range(batch_size): packed_output = torch.cat([packed_output, output[i, : 
step_lens[i]]], dim=0) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 85497fc1fb..1d9a9d96ce 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -41,7 +41,7 @@ void copy_to_kv_cache_non_paged( torch::Tensor step_lens, torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, - int h, int d, + int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int max_ctx_tokens, int max_tokens); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 513dc78d26..601ebfc5cf 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -21,16 +21,87 @@ __global__ void copy_to_kv_cache_non_paged_kernel( int* step_lens, int* seq_lens, NVTE_QKV_Format qkv_format, - int h, int d, + int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int max_ctx_tokens, int max_tokens) { // new_k, new_v: qkv_format; k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h * d; - int new_token_offset = batch_idx * max_ctx_len * h * d; - int cache_offset = batch_idx * max_seq_len * h * d + (seq_lens[batch_idx] - step_lens[batch_idx]) * h * d; + int num_elts_k = step_lens[batch_idx] * h_kv * d_k; + int num_elts_v = step_lens[batch_idx] * h_kv * d_v; + int new_token_offset_k = batch_idx * max_ctx_len * h_kv * d_k; + int new_token_offset_v = batch_idx * max_ctx_len * h_kv * d_v; + int cache_offset_k = batch_idx * max_seq_len * h_kv * d_k + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_k; + int cache_offset_v = batch_idx * max_seq_len * h_kv * d_v + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_v; + + scalar_t* 
new_k_token = new_k + new_token_offset_k; + scalar_t* k_cache_token = k_cache + cache_offset_k; + scalar_t* new_v_token = new_v + new_token_offset_v; + scalar_t* v_cache_token = v_cache + cache_offset_v; + + for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { + *(k_cache_token + i) = *(new_k_token + i); + } + for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { + *(v_cache_token + i) = *(new_v_token + i); + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - step_lens[batch_idx]); + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + for (int j = 0; j < h_kv * d_k; j ++) { + *(k_cache + (cache_offset + i) * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k +j); + } + for (int j = 0; j < h_kv * d_v; j ++) { + *(v_cache + (cache_offset + i) * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v +j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + // no padding between sequences in new_k and new_v + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts_k = step_lens[batch_idx] * h_kv * d_k; + int num_elts_v = step_lens[batch_idx] * h_kv * d_v; + int new_token_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + new_token_offset += step_lens[t]; + } + int cache_offset = batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]; + + scalar_t* new_k_token = new_k + new_token_offset * h_kv * d_k; + scalar_t* k_cache_token = k_cache + cache_offset * h_kv * d_k; + scalar_t* new_v_token = new_v + new_token_offset * h_kv * d_v; + scalar_t* v_cache_token = v_cache + cache_offset * h_kv * d_v; + + for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { + *(k_cache_token + i) = *(new_k_token + i); + } + for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { + *(v_cache_token + i) = 
*(new_v_token + i); + } + } + } +} +template +__global__ void copy_to_kv_cache_non_paged_kernel_same_d( + scalar_t* new_k, scalar_t* new_v, + scalar_t* k_cache, scalar_t* v_cache, + int* batch_indices, + int* step_lens, + int* seq_lens, + NVTE_QKV_Format qkv_format, + int h_kv, int d_kv, + int b, int max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // batch_indices, step_lens, seq_lens: [b + 1] + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h_kv * d_kv; + int new_token_offset = batch_idx * max_ctx_len * h_kv * d_kv; + int cache_offset = batch_idx * max_seq_len * h_kv * d_kv + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; scalar_t* new_k_token = new_k + new_token_offset; scalar_t* k_cache_token = k_cache + cache_offset; @@ -46,22 +117,22 @@ __global__ void copy_to_kv_cache_non_paged_kernel( for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - step_lens[batch_idx]); for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h * d; j ++) { - *(k_cache + (cache_offset + i) * h * d + j) = *(new_k + (i * b + batch_idx) * h * d +j); - *(v_cache + (cache_offset + i) * h * d + j) = *(new_v + (i * b + batch_idx) * h * d +j); + for (int j = 0; j < h_kv * d_kv; j ++) { + *(k_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_k + (i * b + batch_idx) * h_kv * d_kv +j); + *(v_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_v + (i * b + batch_idx) * h_kv * d_kv +j); } } } } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { // no padding between sequences in new_k and new_v for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h * d; + int num_elts = step_lens[batch_idx] * h_kv * d_kv; 
int new_token_offset = 0; for (int t = 0; t < batch_idx; t ++) { new_token_offset += step_lens[t]; } - new_token_offset = new_token_offset * h * d; - int cache_offset = batch_idx * max_seq_len * h * d + (seq_lens[batch_idx] - step_lens[batch_idx]) * h * d; + new_token_offset = new_token_offset * h_kv * d_kv; + int cache_offset = (batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; scalar_t* new_k_token = new_k + new_token_offset; scalar_t* k_cache_token = k_cache + cache_offset; @@ -76,6 +147,61 @@ __global__ void copy_to_kv_cache_non_paged_kernel( } } template +__global__ void copy_to_kv_cache_non_paged_kernel_q( + scalar_t* new_k, + scalar_t* k_cache, + int* batch_indices, + int* step_lens, + int* seq_lens, + NVTE_QKV_Format qkv_format, + int h_kv, int d_kv, + int b, int max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // batch_indices, step_lens, seq_lens: [b + 1] + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h_kv * d_kv; + int new_token_offset = batch_idx * max_ctx_len * h_kv * d_kv; + int cache_offset = batch_idx * max_seq_len * h_kv * d_kv + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; + + scalar_t* new_k_token = new_k + new_token_offset; + scalar_t* k_cache_token = k_cache + cache_offset; + + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(k_cache_token + i) = *(new_k_token + i); + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - step_lens[batch_idx]); + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + for (int j = 0; j < h_kv * d_kv; j ++) { + *(k_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_k + (i * b + 
batch_idx) * h_kv * d_kv +j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + // no padding between sequences in new_k and new_v + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h_kv * d_kv; + int new_token_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + new_token_offset += step_lens[t]; + } + new_token_offset = new_token_offset * h_kv * d_kv; + int cache_offset = (batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; + + scalar_t* new_k_token = new_k + new_token_offset; + scalar_t* k_cache_token = k_cache + cache_offset; + + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(k_cache_token + i) = *(new_k_token + i); + } + } + } +} +template void copy_to_kv_cache_non_paged_launcher( torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, @@ -83,9 +209,10 @@ void copy_to_kv_cache_non_paged_launcher( torch::Tensor step_lens, torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, - int h, int d, + int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int max_ctx_tokens, int max_tokens) { + if (new_v.data_ptr() != nullptr && d_k != d_v) { copy_to_kv_cache_non_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(new_v.data_ptr()), @@ -94,7 +221,28 @@ void copy_to_kv_cache_non_paged_launcher( batch_indices.data_ptr(), step_lens.data_ptr(), seq_lens.data_ptr(), - qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + } + if (new_v.data_ptr() != nullptr && d_k == d_v) { + copy_to_kv_cache_non_paged_kernel_same_d<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + 
batch_indices.data_ptr(), + step_lens.data_ptr(), + seq_lens.data_ptr(), + qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + } + if (new_v.data_ptr() == nullptr) { + copy_to_kv_cache_non_paged_kernel_q<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + batch_indices.data_ptr(), + step_lens.data_ptr(), + seq_lens.data_ptr(), + qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + } } void copy_to_kv_cache_non_paged( @@ -104,19 +252,19 @@ void copy_to_kv_cache_non_paged( torch::Tensor step_lens, torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, - int h, int d, + int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int max_ctx_tokens, int max_tokens) { if (k_cache.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); } else if (k_cache.scalar_type() == at::ScalarType::Float) { using dtype = float; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); 
+ copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); } else { NVTE_ERROR("Unsupported dtype.\n"); } diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 92e20b2340..68bad79ed3 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -252,6 +252,10 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument for module in func.modules(): hook = module.register_forward_hook(hook_fn) hooks.append(hook) + print(len(args), [x.shape for x in args]) + print(len(args), [x.dtype for x in args]) + print(args[0][8,0,:4]) + print(kwargs) outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: hook.remove() diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 919ca8eeaf..ab3db9e7af 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -178,9 +178,9 @@ def step( k_cache, v_cache = self.cache[layer_number] step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - h=16 - d=64 - b=4 + #h=self.num_heads #16 + #d=self.head_dim_k #64 + #b=self.max_batch_size #4 max_ctx_len=k.shape[1] #64 max_seq_len=k_cache.shape[1] #64 #128 max_ctx_tokens=k.shape[0] @@ -189,7 +189,12 @@ def step( #print('step_lens ', step_lens) #print('seq_lens ', seq_lens) #print('self.batch_indices ', self.batch_indices) - tex.copy_to_kv_cache_non_paged(k, v, k_cache, v_cache, self.batch_indices, step_lens, seq_lens, QKVFormat[qkv_format], h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + print('lensss ', max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + tex.copy_to_kv_cache_non_paged( + k, v, k_cache, v_cache, + self.batch_indices, step_lens, seq_lens, 
+ QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, self.max_batch_size, + max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) return k_cache, v_cache, None # #prev_batch_size = len(self.sequences) From f0d22ca12f574233053da20516997e45d99eb65c Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 13 Feb 2025 09:55:38 +0800 Subject: [PATCH 074/239] Fix a bug for D being nullptr in grouped gemm (#1475) * fix a bug for at::from_blob with nullptr Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix a bug for non-TN Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 36 ++++++++++++------- .../pytorch/csrc/extensions/gemm.cpp | 8 +++-- .../pytorch/module/grouped_linear.py | 3 +- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 2401f3ca95..22735c5292 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2131,21 +2131,30 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): if layout == "TN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output + B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = False + single_output = True elif layout == "NN": A = [torch.randn(n, k, dtype=dtype, device="cuda") 
for _ in range(z)] # weight - B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output - out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # dgrad + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = True + single_output = True else: # layout == "NT" - A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output + A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [o.clone() for o in out] grad = True + single_output = False - out_ref = [o.clone() for o in out] for i in range(z): general_gemm( A[i], @@ -2157,17 +2166,20 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): layout=layout, out=out_ref[i], ) + if single_output: + out_ref = [torch.cat(out_ref)] general_grouped_gemm( A, - list(B), - list(out), + B, + out, dtype, get_multi_stream_cublas_workspace(), - m_splits=[k] * n, # TODO, not sure + m_splits=m_splits, grad=grad, accumulate=accumulate, layout=layout, + single_output=single_output, ) # should be bit-wise match @@ -2190,7 +2202,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): pytest.skip(reason_for_no_fp8) z, m, k, n = shape - m_splits = m // z + m_splits = [m // z] * z dtype = torch.bfloat16 A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight @@ -2242,7 +2254,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): out, dtype, get_multi_stream_cublas_workspace(), - m_splits=[k] * m_splits, + m_splits=m_splits, accumulate=accumulate, ) 
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b044c9f604..54bd52f136 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -336,9 +336,13 @@ std::optional> te_general_grouped_gemm( auto dtype = GetATenDType(D_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); if (single_output) { - out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + if (output_data_ptr == nullptr) { + out_tensor = at::empty(D_shape, opts); + } else { + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + } char* char_ptr = reinterpret_cast(output_data_ptr); - char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size(); + char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size(); output_data_ptr = reinterpret_cast(char_ptr); D_vectors.emplace_back(out_tensor); } else { diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2f9de58984..cab8dff7c2 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -269,9 +269,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], general_grouped_gemm( weights, grad_output, - torch.split(dgrad, ctx.m_splits), + [dgrad], ctx.activation_dtype, get_multi_stream_cublas_workspace(), + single_output=True, layout="NN", m_splits=ctx.m_splits, grad=True, From 54ae0c775eb3aafaa96b0b0dbe7c8abec002ce1d Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Thu, 13 Feb 2025 16:27:09 -0800 Subject: [PATCH 075/239] WIP: non-paged, thd, CG Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 12 +- transformer_engine/pytorch/attention.py | 69 ++++++++---- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/attention.cu | 105 +++++++++++++++--- 
transformer_engine/pytorch/graph.py | 2 +- .../pytorch/kv_cache_manager_non_paged.py | 6 +- 6 files changed, 150 insertions(+), 47 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index dbae01ad81..2b6259886d 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -210,7 +210,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats) @pytest.mark.parametrize("is_paged", [False])#, True]) @pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("is_cuda_graph", [True])#False])#, True]) +@pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() logger = logging.getLogger("test_paged_attn") @@ -353,7 +353,7 @@ def generate_data( ) for _ in range(3) ] - print(aa[0].shape, aa[0][8,0,:4]) + #print(aa[0].shape, aa[0][8,0,:4]) #aa.extend([model_config.sequence_length, model_config.sequence_length]) return aa @@ -397,7 +397,7 @@ def gen_cu( model = make_graphed_callables( model, generate_data(config, dtype, warmup=True, qkv_format=qkv_format), - num_warmup_iters=10, + num_warmup_iters=3, #10, fp8_enabled=False, #sample_kwargs={"qkv_format":"thd"}, sample_kwargs=gen_cu(config, dtype), @@ -460,6 +460,7 @@ def gen_cu( dim=0, ) if is_cuda_graph: + print('incremental qkv shapes ', [x.shape for x in [incremental_q, incremental_k, incremental_v]]) incremental_q = torch.cat([incremental_q, torch.zeros([max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], dtype=dtype, device=incremental_q.device)], dim=0) incremental_k = torch.cat([incremental_k, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_k.device)], dim=0) incremental_v = torch.cat([incremental_v, 
torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_v.device)], dim=0) @@ -509,8 +510,8 @@ def gen_cu( cu_seqlens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) print('qkv_format' ,qkv_format, cu_seqlens_q, cu_seqlens_kv) - print("q[1, 8:10, :2, :2]", q[1, 8:10, :2, :2]) - print("inc_q[18:20, :2, :2]", incremental_q[18:20, :2, :2]) + #print("q[1, 8:10, :2, :2]", q[1, 8:10, :2, :2]) + #print("inc_q[18:20, :2, :2]", incremental_q[18:20, :2, :2]) step_dict = OrderedDict( zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) @@ -593,6 +594,7 @@ def gen_cu( rtol=tols[dtype], ) if qkv_format == "thd": + print('iiii ', i, cu_seqlens_q, sim.t_total_lens) print('thd ', seq, sim.t_total_lens[i], cu_seqlens_q[i + 1]) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(line_output[cu_seqlens_q[i + 1] - 1, :4]) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7b44517773..637dfe38bc 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1219,6 +1219,7 @@ def allocate_memory(self, layer_number: int, qkv_format: str): self.cache_manager.allocate_memory(layer_number) if qkv_format == 'thd': #self.is_cuda_graph: #self.max_seqlen_q = self.max_seqlen_kv + self.q_orig = {} self.q_buffer = {} self.q_buffer[layer_number] = torch.zeros( self.max_batch_size, @@ -1463,44 +1464,52 @@ def update_cache( #actual_batch_size = len(self.step_dict) #seqlens_q = list(self.step_dict.values()) #cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] - print('qkv_foramt', qkv_format) + #print('qkv_foramt', qkv_format) #print('qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) #print('qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) seqlens_q = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] batch_size = len(seqlens_q) 
+ self.q_orig[layer_number] = q if qkv_format == "bshd": - q = q.contiguous() + q_buffer = q.contiguous() + max_seqlen_q = q_buffer.shape[1] if qkv_format == "sbhd": - q = q.transpose(0, 1).contiguous() - max_seqlen_q = q.shape[1] + q_buffer = q.transpose(0, 1).contiguous() + max_seqlen_q = q_buffer.shape[1] if qkv_format == "thd": #print('---qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) #print('---qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) q_buffer = self.q_buffer[layer_number] + #q_buffer_copy = self.q_buffer[layer_number].clone() ##for i in range(actual_batch_size): #for i in range(batch_size): # q_buffer[i, : seqlens_q[i], :, :] = q[self.cu_seqlens_q[i] : self.cu_seqlens_q[i + 1], :, :] - #q = q_buffer + ##q = q_buffer max_seqlen_q = self.max_ctx_len + #max_seqlen_kv = self.max_seqlen_kv step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] - max_ctx_len=q.shape[1] #64 - max_seq_len=q_buffer.shape[1] #64 #128 - max_ctx_tokens=q.shape[0] - max_tokens=q_buffer.shape[0]*q_buffer.shape[1] + #seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] + max_ctx_len=q.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + max_seq_len=self.max_ctx_len #q_buffer.shape[1] #64 #128 + #max_ctx_tokens=q.shape[0] + #max_tokens=q_buffer.shape[0]*q_buffer.shape[1] #print('---++qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) #print('---++qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) #print(q_buffer.shape) - #print(step_lens, seq_lens, QKVFormat[qkv_format]) - #print(self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, - # max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + #print(self.cu_seqlens_q, self.cu_seqlens_kv, step_lens, seq_lens, QKVFormat[qkv_format]) + print('q xxxxxxxxxxxx ',self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, + max_ctx_len, max_seq_len)#, max_ctx_tokens, max_tokens) # TODO: batch_indices tex.copy_to_kv_cache_non_paged( q, self.q_dummy, 
q_buffer, self.q_dummy, - self.batch_indices, step_lens, seq_lens, + self.batch_indices, step_lens, step_lens, QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, - max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) + #q = q_buffer + #q_buffer = q_buffer_copy + #torch.save(q_buffer, 'q_buffer.pt') + #torch.save(q_buffer_copy, 'q_buffer_copy.pt') #q = q_buffer #print('qqqqqqqq', q_buffer.shape, q_buffer.dtype, q_buffer[:2, 8:10, 0, :4]) @@ -8238,10 +8247,32 @@ def forward( if inference_params.input_qkv_format == "sbhd": output = output[:batch_size, :max_seqlen_q].transpose(0, 1).contiguous() if inference_params.input_qkv_format == "thd": - packed_output = torch.Tensor().to(dtype=output.dtype, device=output.device) - for i in range(batch_size): - packed_output = torch.cat([packed_output, output[i, : step_lens[i]]], dim=0) - output = packed_output.contiguous() + output_buffer = inference_params.q_orig[self.layer_number] + #packed_output = torch.Tensor().to(dtype=output.dtype, device=output.device) + #for i in range(batch_size): + # packed_output = torch.cat([packed_output, output[i, : step_lens[i]]], dim=0) + #output = packed_output.contiguous() + + #max_seqlen_kv = self.max_seqlen_kv + step_lens = inference_params.cu_seqlens_q[1:] - inference_params.cu_seqlens_q[:-1] + #seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] + max_ctx_len=1 #output.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + max_seq_len=inference_params.max_ctx_len #q_buffer.shape[1] #64 #128 + #max_ctx_tokens=q.shape[0] + #max_tokens=q_buffer.shape[0]*q_buffer.shape[1] + #print('---++qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) + #print('---++qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) + #print(q_buffer.shape) + #print(self.cu_seqlens_q, self.cu_seqlens_kv, step_lens, seq_lens, QKVFormat[qkv_format]) + #print('o xxxxxxxxxxxx ',step_lens, #self.num_heads_q, self.head_dim_q, 
self.head_dim_q, self.max_batch_size, + # max_ctx_len, max_seq_len, output.shape, output_buffer.shape)#, max_ctx_tokens, max_tokens) + # TODO: batch_indices + tex.copy_to_kv_cache_non_paged( + inference_params.q_dummy, output, inference_params.q_dummy, output_buffer, + inference_params.batch_indices, step_lens, step_lens, + QKVFormat[qkv_format], inference_params.num_heads_q, inference_params.head_dim_q, inference_params.head_dim_q, inference_params.max_batch_size, + max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) + output = output_buffer.view(output_buffer.shape[0], -1) return output diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1d9a9d96ce..ba2f8e4530 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -42,8 +42,7 @@ void copy_to_kv_cache_non_paged( torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens); + int b, int max_ctx_len, int max_seq_len); NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 601ebfc5cf..1643b39c57 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -22,8 +22,7 @@ __global__ void copy_to_kv_cache_non_paged_kernel( int* seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens) { + int b, int max_ctx_len, int max_seq_len) { // new_k, new_v: qkv_format; k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { @@ -93,8 +92,7 @@ __global__ void 
copy_to_kv_cache_non_paged_kernel_same_d( int* seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_kv, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens) { + int b, int max_ctx_len, int max_seq_len) { // new_k, new_v: qkv_format; k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { @@ -155,8 +153,7 @@ __global__ void copy_to_kv_cache_non_paged_kernel_q( int* seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_kv, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens) { + int b, int max_ctx_len, int max_seq_len) { // new_k, new_v: qkv_format; k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { @@ -202,6 +199,69 @@ __global__ void copy_to_kv_cache_non_paged_kernel_q( } } template +__global__ void copy_to_kv_cache_non_paged_kernel_o( + scalar_t* output, // output in bshd + scalar_t* output_buffer, // output in thd + //int* batch_indices, + int* step_lens, + //int* seq_lens, + NVTE_QKV_Format qkv_format, + int h_q, int d_q, + int b, int max_seq_len) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // batch_indices, step_lens, seq_lens: [b + 1] + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h_q * d_q; + int output_offset = batch_idx * max_seq_len * h_q * d_q; + int output_buffer_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + output_buffer_offset += step_lens[t]; + } + output_buffer_offset = output_buffer_offset * h_q * d_q; + scalar_t* output_token = output + output_offset; + scalar_t* output_buffer_token = output_buffer + output_buffer_offset; + + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + if (batch_idx < 2 && i < 3) { + printf("h_q %d d_q %d, b %d t %d, output_offset %d output_buffer_offset %d, output_buffer_token + i %p 
output_token + i %p\n",h_q, d_q,batch_idx, i, output_offset, output_buffer_offset, output_buffer_token + i, output_token + i); + } + *(output_buffer_token + i) = *(output_token + i); + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int output_buffer_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + output_buffer_offset += step_lens[t]; + } + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + for (int j = 0; j < h_q * d_q; j ++) { + *(output_buffer + (output_buffer_offset + i) * h_q * d_q + j) = *(output + (i * b + batch_idx) * h_q * d_q +j); + } + } + } +// } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { +// // no padding between sequences in new_k and new_v +// for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { +// int num_elts = step_lens[batch_idx] * h_kv * d_kv; +// int new_token_offset = 0; +// for (int t = 0; t < batch_idx; t ++) { +// new_token_offset += step_lens[t]; +// } +// new_token_offset = new_token_offset * h_kv * d_kv; +// int cache_offset = (batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; +// +// scalar_t* new_k_token = new_k + new_token_offset; +// scalar_t* k_cache_token = k_cache + cache_offset; +// +// for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { +// *(k_cache_token + i) = *(new_k_token + i); +// } +// } + } +} +template void copy_to_kv_cache_non_paged_launcher( torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, @@ -210,9 +270,9 @@ void copy_to_kv_cache_non_paged_launcher( torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens) { + int b, int max_ctx_len, int max_seq_len) { if (new_v.data_ptr() != nullptr && d_k != d_v) { + printf("-------- 1 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); 
copy_to_kv_cache_non_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(new_v.data_ptr()), @@ -221,9 +281,10 @@ void copy_to_kv_cache_non_paged_launcher( batch_indices.data_ptr(), step_lens.data_ptr(), seq_lens.data_ptr(), - qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); } - if (new_v.data_ptr() != nullptr && d_k == d_v) { + if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && d_k == d_v) { + printf("-------- 2 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); copy_to_kv_cache_non_paged_kernel_same_d<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(new_v.data_ptr()), @@ -232,16 +293,27 @@ void copy_to_kv_cache_non_paged_launcher( batch_indices.data_ptr(), step_lens.data_ptr(), seq_lens.data_ptr(), - qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); } if (new_v.data_ptr() == nullptr) { + printf("-------- 3 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); copy_to_kv_cache_non_paged_kernel_q<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(k_cache.data_ptr()), batch_indices.data_ptr(), step_lens.data_ptr(), seq_lens.data_ptr(), - qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); + } + if (new_k.data_ptr() == nullptr) { + printf("-------- 4 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); + copy_to_kv_cache_non_paged_kernel_o<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + //batch_indices.data_ptr(), + step_lens.data_ptr(), + //seq_lens.data_ptr(), + qkv_format, h_kv, d_k, b, max_seq_len); } } @@ -253,18 +325,17 @@ void 
copy_to_kv_cache_non_paged( torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens) { + int b, int max_ctx_len, int max_seq_len) { if (k_cache.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); } else if (k_cache.scalar_type() == at::ScalarType::Float) { using dtype = float; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens); + copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); } else { NVTE_ERROR("Unsupported dtype.\n"); } diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 68bad79ed3..3d252cc1e7 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -254,7 +254,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument hooks.append(hook) print(len(args), [x.shape for x in args]) print(len(args), [x.dtype for x in 
args]) - print(args[0][8,0,:4]) + #print(args[0][8,0,:4]) print(kwargs) outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index ab3db9e7af..767d6dbc0f 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -181,8 +181,8 @@ def step( #h=self.num_heads #16 #d=self.head_dim_k #64 #b=self.max_batch_size #4 - max_ctx_len=k.shape[1] #64 - max_seq_len=k_cache.shape[1] #64 #128 + max_ctx_len=k.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + max_seq_len=self.max_seqlen #k_cache.shape[1] #64 #128 max_ctx_tokens=k.shape[0] max_tokens=k_cache.shape[0]*k_cache.shape[1] print('kv shapes ', [x.shape for x in [k, v, k_cache, v_cache]]) @@ -194,7 +194,7 @@ def step( k, v, k_cache, v_cache, self.batch_indices, step_lens, seq_lens, QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, self.max_batch_size, - max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + max_ctx_len, max_seq_len)#, max_ctx_tokens, max_tokens) return k_cache, v_cache, None # #prev_batch_size = len(self.sequences) From 24e4f955d9de38402585bd0322c6f1e5dabd96cf Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 13 Feb 2025 19:25:03 -0800 Subject: [PATCH 076/239] [JAX] Flax params initialization with weight_dtype (#1481) * initialization with weight_dtype Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/flax/module.py | 104 ++++++++++++++------- transformer_engine/jax/flax/transformer.py | 48 ++++++++-- 2 files changed, 109 insertions(+), 43 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 4c46eafb4c..2190c6df6c 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -8,8 +8,8 @@ import operator from typing 
import Any, Callable, Iterable, List, Sequence, Tuple, Union -import jax.numpy as jnp import numpy as np +import jax.numpy as jnp from flax import linen as nn from flax.linen import partitioning as nn_partitioning from jax import lax @@ -57,14 +57,18 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga def _create_layernorm_parameters( - layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype + layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype ): - scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes) + scale = nn_partitioning.param_with_axes( + "scale", scale_init, shape, weight_dtype, axes=scale_axes + ) scale = scale.astype(dtype) layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == "layernorm": - bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes) + bias = nn_partitioning.param_with_axes( + "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes + ) bias = bias.astype(dtype) else: assert layernorm_type == "rmsnorm" @@ -256,8 +260,10 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 - the data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -272,6 +278,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] 
= ("embed",) dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): @@ -307,6 +314,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: self.bias_init, self.bias_axes, self.dtype, + self.weight_dtype, ) return layernorm( x, @@ -399,8 +407,10 @@ class DenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -418,12 +428,13 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.dtype + 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype ) super().__post_init__() @@ -452,13 +463,13 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes ) kernel = kernel.astype(self.dtype) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, self.dtype, axes=self.bias_axes + "bias", self.bias_init, features, 
self.weight_dtype, axes=self.bias_axes ) bias = bias.astype(self.dtype) else: @@ -489,7 +500,7 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - self.dtype, + self.weight_dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) @@ -501,7 +512,7 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - self.dtype, + self.weight_dtype, axes=lora_b_kernel_axes, ) lora_b_kernel = lora_b_kernel.astype(self.dtype) @@ -594,8 +605,10 @@ class LayerNormDenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -625,6 +638,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] 
= None @@ -633,7 +647,10 @@ class LayerNormDenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.dtype + 1.0, + "fan_in", + "truncated_normal", + dtype=self.weight_dtype, ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init, @@ -683,6 +700,7 @@ def __call__(self, inputs: Array) -> Array: self.ln_bias_init, self.ln_bias_axes, self.dtype, + self.weight_dtype, ) if not fuse_layernorm: @@ -712,7 +730,7 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes ) kernel = kernel.astype(self.dtype) @@ -757,7 +775,7 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - self.dtype, + self.weight_dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) @@ -769,7 +787,7 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - self.dtype, + self.weight_dtype, axes=lora_b_kernel_axes, ) lora_b_kernel = lora_b_kernel.astype(self.dtype) @@ -781,7 +799,7 @@ def __call__(self, inputs: Array) -> Array: bias = None if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, self.dtype, axes=self.bias_axes + "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes ) bias = bias.astype(self.dtype) @@ -896,8 +914,10 @@ class LayerNormMLP(TransformerEngineBase): Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the 
initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -930,6 +950,7 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None @@ -938,7 +959,7 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.dtype + 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init, @@ -1015,6 +1036,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: self.ln_bias_init, self.ln_bias_axes, self.dtype, + self.weight_dtype, ) if not fuse_layernorm: @@ -1061,7 +1083,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, kernel_1_each_shape, - self.dtype, + self.weight_dtype, axes=self.kernel_axes_1, ) kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) @@ -1074,7 +1096,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_kernel", self.kernel_init, kernel_2_param_shape, - self.dtype, + self.weight_dtype, axes=self.kernel_axes_2, ) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) @@ -1090,13 +1112,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: bias_1_shape = intermediate_dim bias_1 = nn_partitioning.param_with_axes( - "wi_bias", self.bias_init, bias_1_shape, self.dtype, 
axes=self.bias_axes_1 + "wi_bias", + self.bias_init, + bias_1_shape, + self.weight_dtype, + axes=self.bias_axes_1, ) bias_1 = bias_1.astype(self.dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( - "wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2 + "wo_bias", + self.bias_init, + bias_2_shape, + self.weight_dtype, + axes=self.bias_axes_2, ) bias_2 = bias_2.astype(self.dtype) else: @@ -1165,7 +1195,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, wi_lora_a_kernel_init_each_shape, - self.dtype, + self.weight_dtype, axes=wi_lora_a_kernel_axes, ) wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) @@ -1181,7 +1211,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_lora_b_kernel", nn.initializers.zeros, wi_lora_b_kernel_shape, - self.dtype, + self.weight_dtype, axes=wi_lora_b_kernel_axes, ) wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) @@ -1198,7 +1228,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_1 = None if self.use_bias: bias_1 = nn_partitioning.param_with_axes( - "wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1 + "wi_bias", + self.bias_init, + intermediate_dim, + self.weight_dtype, + axes=self.bias_axes_1, ) bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape bias_1 = bias_1.astype(self.dtype) @@ -1240,7 +1274,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_a_kernel", self.kernel_init, wo_lora_a_kernel_shape, - self.dtype, + self.weight_dtype, axes=wo_lora_a_kernel_axes, ) wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) @@ -1251,7 +1285,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_b_kernel", nn.initializers.zeros, wo_lora_b_kernel_shape, - self.dtype, + self.weight_dtype, axes=wo_lora_b_kernel_axes, ) wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) @@ -1268,7 +1302,11 @@ def kernel_1_init(key, 
num_kernels, stack_axis, *init_args): bias_2 = None if self.use_bias: bias_2 = nn_partitioning.param_with_axes( - "wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2 + "wo_bias", + self.bias_init, + (hidden_size,), + self.weight_dtype, + axes=self.bias_axes_2, ) bias_2 = bias_2.astype(self.dtype) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 89278f720b..6c96e7ba1a 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -115,6 +115,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 float32_logits: bool = False scale_factor: Optional[float] = None transpose_batch_sequence: bool = True @@ -261,6 +262,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD scale_factor: Optional[float] = None transpose_batch_sequence: bool = False @@ -480,8 +482,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- - dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. 
""" head_dim: int @@ -491,6 +495,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods attn_mask_type: AttnMaskType = "causal" attn_bias_type: AttnBiasType = None dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 dropout_rng_name: str = "dropout" float32_logits: bool = False qkv_layout: str = "bshd_bshd_bshd" @@ -615,6 +620,7 @@ def __call__( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, dtype=self.dtype, + weight_dtype=self.weight_dtype, float32_logits=self.float32_logits, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, @@ -626,6 +632,7 @@ def __call__( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, dtype=self.dtype, + weight_dtype=self.weight_dtype, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, @@ -880,8 +887,10 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- - dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. 
fuse_qkv_params: bool, default = True If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for @@ -927,6 +936,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 fuse_qkv_params: bool = True transpose_batch_sequence: bool = True enable_sequence_parallel: bool = False @@ -977,7 +987,7 @@ def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", dtype=self.dtype + 1.0, "fan_in", "normal", self.weight_dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1105,6 +1115,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): dot_input_axes=inputs_logical_axes_no_sp, name="qkv", dtype=self.dtype, + weight_dtype=self.weight_dtype, )(inputs_q) qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj") qkv_layout = QKVLayout.BS3HD @@ -1128,6 +1139,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, + weight_dtype=self.weight_dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, @@ -1152,6 +1164,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, name="kv", dtype=self.dtype, + weight_dtype=self.weight_dtype, )(inputs_kv) kv_proj = checkpoint_name(kv_proj, "combined_kv_proj") qkv_layout = QKVLayout.BSHD_BS2HD @@ -1169,6 +1182,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, + weight_dtype=self.weight_dtype, ) query, ln_out = 
LayerNormDenseGeneral( enable_layernorm=self.input_layernorm, @@ -1189,6 +1203,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, + weight_dtype=self.weight_dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, @@ -1326,6 +1341,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): attn_bias_type=self.attn_bias_type, attention_dropout=self.attention_dropout, dtype=self.dtype, + weight_dtype=self.weight_dtype, dropout_rng_name=self.dropout_rng_name, float32_logits=self.float32_logits, qkv_layout=qkv_layout.name, @@ -1351,6 +1367,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, + weight_dtype=self.weight_dtype, name="out", )(x) out = checkpoint_name(out, "out_proj") @@ -1379,7 +1396,9 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. """ num_buckets: int @@ -1388,6 +1407,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho embedding_init: Callable[..., Array] = nn.linear.default_embed_init embedding_axes: Tuple[str, ...] 
= ("heads", "relpos_buckets") dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 @nn.compact def __call__(self, q_seqlen, k_seqlen, bidirectional=True): @@ -1440,7 +1460,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True): "rel_embedding", self.embedding_init, (self.num_attention_heads, self.num_buckets), - self.dtype, + self.weight_dtype, axes=self.embedding_axes, ) @@ -1613,7 +1633,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + The data type used for computation. + weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type of the module parameters. drop_path: float, default = 0.0 When > 0.0, applies stochastic depth per sample in the main path of the residual block. @@ -1666,6 +1688,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 drop_path: float = 0.0 fuse_qkv_params: bool = True transpose_batch_sequence: bool = False @@ -1677,11 +1700,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): if self.mha_kernel_init is None: self.mha_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", dtype=self.dtype + 1.0, "fan_in", "normal", dtype=self.weight_dtype ) if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.dtype + 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1771,6 +1794,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, + 
weight_dtype=self.weight_dtype, embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), name="relpos_bias", ) @@ -1804,6 +1828,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): x, ln_out = MultiHeadAttention( num_attention_heads=self.num_attention_heads, dtype=self.dtype, + weight_dtype=self.weight_dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, @@ -1882,6 +1907,7 @@ def hidden_dropout(x, deterministic): y, ln_out = MultiHeadAttention( num_attention_heads=self.num_attention_heads, dtype=self.dtype, + weight_dtype=self.weight_dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, @@ -1947,6 +1973,7 @@ def hidden_dropout(x, deterministic): intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, dtype=self.dtype, + weight_dtype=self.weight_dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_init=self.mlp_kernel_init, @@ -1996,6 +2023,7 @@ def hidden_dropout(x, deterministic): bias_axes=(W_NO_SHARD_AXES,), transpose_batch_sequence=self.transpose_batch_sequence, dtype=self.dtype, + weight_dtype=self.weight_dtype, name="output_layernorm", )(z) From e19b8281453ec1835319448bba93697fc8b0f537 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 14 Feb 2025 08:11:32 -0800 Subject: [PATCH 077/239] [JAX] Fixes for CI failures with the latest JAX (#1469) * fixes L1 test * fix test_multigpu_encoder * fixes for other multi-encoder tests * jax.extend.ffi to jax.ffi * initialization with float32 * add init_dtype as an optional arg to all modules * update use_scan query from xla flags * relax threshold for test_encoder fp8 * relax the tols --------- Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 6 +++--- examples/jax/encoder/test_multigpu_encoder.py | 2 +- .../encoder/test_multiprocessing_encoder.py 
| 4 ++-- .../jax/encoder/test_single_gpu_encoder.py | 2 +- qa/L1_jax_distributed_unittest/test.sh | 8 +------- tests/jax/test_distributed_fused_attn.py | 20 ++++++++++++++----- .../jax/cpp_extensions/activation.py | 2 +- .../jax/cpp_extensions/attention.py | 6 ++---- .../jax/cpp_extensions/custom_call.py | 7 +++---- .../jax/cpp_extensions/normalization.py | 2 +- .../jax/cpp_extensions/quantization.py | 2 +- .../jax/cpp_extensions/softmax.py | 2 +- .../jax/cpp_extensions/transpose.py | 2 +- transformer_engine/jax/flax/transformer.py | 7 ++++--- 14 files changed, 37 insertions(+), 35 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 918dfd8238..f02cc562b5 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -239,7 +239,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding @@ -447,7 +447,7 @@ def test_te_fp8(self): """Test Transformer Engine with FP8""" self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_sp(self): @@ -462,7 +462,7 @@ def test_te_fp8_sp(self): self.args.enable_sp = True self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.785 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py 
b/examples/jax/encoder/test_multigpu_encoder.py index c0325d3e28..eb4a1d0afb 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -218,7 +218,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 7d2df77b7d..91186a15c4 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -320,7 +320,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding @@ -587,7 +587,7 @@ def test_te_bf16(self): def test_te_fp8(self): """Test Transformer Engine with FP8""" result = self.exec(True) - assert result[0] < 0.45 and result[1] > 0.79 + assert result[0] < 0.455 and result[1] > 0.79 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index b2439278ea..dd1997fe6f 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -334,7 +334,7 @@ def test_te_fp8(self): """Test Transformer Engine with FP8""" self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 
0.79 if __name__ == "__main__": diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index deb0f93cec..e47aa15fbd 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -6,10 +6,4 @@ set -xe : ${TE_PATH:=/opt/transformerengine} -# Skip ring attention tests since they need fixed environment vars -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn' - -# Test ring attention with and without scan loop -NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn -NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \ - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index d7e015dbf7..898993f5d1 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. 
+import os +import pytest import jax import jax.numpy as jnp import numpy as np @@ -11,7 +13,7 @@ generate_context_parallel_configs, generate_collectives_count, ) -from transformer_engine.jax import fp8_autocast +from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, AttnBiasType, @@ -22,10 +24,7 @@ inverse_reorder_causal_load_balancing, CPStrategy, ) -from transformer_engine.jax.sharding import MeshResource -import pytest -from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat DTYPES = [jnp.bfloat16] @@ -355,6 +354,10 @@ def test_context_parallel_allgather_attn( CPStrategy.ALL_GATHER, ) + @pytest.mark.parametrize( + "use_scan", + [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], + ) def test_context_parallel_ring_attn( self, device_count, @@ -367,8 +370,14 @@ def test_context_parallel_ring_attn( dtype, qkv_layout, load_balanced, + use_scan, ): - return self.impl_test_context_parallel_attn( + if use_scan: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" + else: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" + + self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -381,6 +390,7 @@ def test_context_parallel_ring_attn( load_balanced, CPStrategy.RING, ) + del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] class TestReorderCausalLoadBalancing: diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 4a29fce2c4..076ec98aba 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -11,7 +11,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import 
NVTE_Activation_Type diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 5ec556ab34..1c32ef4cba 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -15,7 +15,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor @@ -1602,9 +1602,7 @@ def use_scanloop(): def truthy(val): return val.lower() in ["1", "true"] - x = use_scan and get_xla_flag( - "--xla_experimental_ignore_channel_id", default=False, cast=truthy - ) + x = use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy) return x def check_supported(self): diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 6739ac8bda..6f6c9962cf 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -5,9 +5,8 @@ from dataclasses import dataclass from enum import IntEnum +import jax from jax.interpreters import mlir -import jax.extend as jex - from transformer_engine import transformer_engine_jax from .misc import is_ffi_enabled @@ -30,11 +29,11 @@ class CustomCallAPIVersion(IntEnum): for _name, _value in transformer_engine_jax.registrations().items(): if _name.endswith("_ffi"): if is_ffi_enabled(): - jex.ffi.register_ffi_target( + jax.ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value ) else: - jex.ffi.register_ffi_target( + jax.ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 
d7512b0e70..1107dd3a0f 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -13,7 +13,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c3ea8cb7aa..2f29a64f18 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -9,7 +9,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 5c55dd3672..dba1f504da 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -12,7 +12,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index d07b6944fb..bb9b104e7e 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -11,7 +11,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax from 
transformer_engine.transformer_engine_jax import DType as TEDType diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 6c96e7ba1a..fbae73f131 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -150,8 +150,8 @@ def __call__( del self.scale_factor if self.float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) + query = query.astype(self.dtype) + key = key.astype(self.dtype) h_q, h_kv = query.shape[-2], key.shape[-2] # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # Therefore, we have to maintain two code paths. @@ -989,6 +989,7 @@ def __post_init__(self): self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "normal", self.weight_dtype ) + self.kernel_init = _kernel_init.astype(self.dtype) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -1281,7 +1282,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): f"expected query shape {expected_shape} instead got {query.shape}." 
) - cur_index = cache_index.value + cur_index = cache_index.value.astype(jnp.int32) one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape) key = cached_key.value + key * one_hot_indices From 654c929014e3fcc2ce40298a84cfe8209466669a Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 14 Feb 2025 12:26:14 -0800 Subject: [PATCH 078/239] WIP: non-paged, CG Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 23 +++--- transformer_engine/pytorch/attention.py | 18 +++-- .../pytorch/csrc/extensions/attention.cu | 70 +++++++++++++++++-- .../pytorch/kv_cache_manager_non_paged.py | 15 +++- 4 files changed, 105 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 2b6259886d..3b740392f7 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -74,10 +74,10 @@ def __init__( self.seq_ids = torch.range(0, total_requests-1, dtype=torch.int32, device="cpu") # simulate context lengths in Uniform distribution - #self.context_lens = torch.randint( - # 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" - #) - self.context_lens = 10 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + self.context_lens = torch.randint( + 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" + ) + #self.context_lens = 10 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -85,10 +85,10 @@ def __init__( gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to( dtype=torch.int32, device="cpu" ) - #self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( - # dtype=torch.int32, device="cpu" - #) - self.gen_lens = 5 * torch.ones(total_requests, dtype=torch.int32, device="cpu") 
+ self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( + dtype=torch.int32, device="cpu" + ) + #self.gen_lens = 5 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate arrival times in Poisson distribution if poisson_rate is None: @@ -207,7 +207,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("dtype", [torch.float16])#param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats) +@pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False])#, True]) @pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) @@ -587,6 +587,9 @@ def gen_cu( rtol=tols[dtype], ) if qkv_format == "sbhd": + print(i,seq, sim.t_total_lens[i], sim.step_lens[i]) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[sim.step_lens[i] - 1, i, :4]) torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], line_output[sim.step_lens[i] - 1, i, :], @@ -598,6 +601,8 @@ def gen_cu( print('thd ', seq, sim.t_total_lens[i], cu_seqlens_q[i + 1]) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(line_output[cu_seqlens_q[i + 1] - 1, :4]) + #print(line_output[cu_seqlens_q[1 + 1] - 1, :4]) + #print(line_output[cu_seqlens_q[2 + 1] - 1, :4]) torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], line_output[cu_seqlens_q[i + 1] - 1, :], diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 637dfe38bc..088ef8db65 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1258,9 +1258,13 @@ def prepare( actual_batch_size = len(self.step_dict) seqlens_q = list(self.step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + cu_seqlens_q = cu_seqlens_q + 
[cu_seqlens_q[-1]] * ( + self.max_batch_size - actual_batch_size + ) self.seq_lens = list(self.sequences.values()) - self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( + #self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( + self.cu_seqlens_q.copy_( torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") ) cu_seqlens_kv = [0] + [sum(self.seq_lens[:i]) for i in range(1, actual_batch_size + 1)] @@ -1469,7 +1473,6 @@ def update_cache( #print('qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) seqlens_q = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] batch_size = len(seqlens_q) - self.q_orig[layer_number] = q if qkv_format == "bshd": q_buffer = q.contiguous() max_seqlen_q = q_buffer.shape[1] @@ -1479,6 +1482,7 @@ def update_cache( if qkv_format == "thd": #print('---qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) #print('---qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) + self.q_orig[layer_number] = q q_buffer = self.q_buffer[layer_number] #q_buffer_copy = self.q_buffer[layer_number].clone() ##for i in range(actual_batch_size): @@ -7951,6 +7955,8 @@ def forward( ) print('max_seqlen_q ', max_seqlen_q) print('max_seqlen_kv ', max_seqlen_kv) + #print('cu_seqlens_q ', cu_seqlens_q) + #print('cu_seqlens_kv ', cu_seqlens_kv) if ( isinstance(query_layer, Float8Tensor) @@ -8241,7 +8247,10 @@ def forward( batch_size = len(inference_params.step_dict) step_lens = list(inference_params.step_dict.values()) max_seqlen_q = max(list(inference_params.step_dict.values())) - print('xxxxxxxxx ', batch_size, step_lens, max_seqlen_q, inference_params.step_dict, inference_params.input_qkv_format) + print('xxxxxxxxx ', batch_size, step_lens, max_seqlen_q, inference_params.step_dict, inference_params.input_qkv_format, output.shape) + #ooo = output.view(output.shape[:2], -1) + #print('output ', output[0,0,:4]) + #print('output ', output[1,0,:4]) if inference_params.input_qkv_format == "bshd": output = output[:batch_size, :max_seqlen_q].contiguous() if inference_params.input_qkv_format == "sbhd": @@ -8254,7 
+8263,8 @@ def forward( #output = packed_output.contiguous() #max_seqlen_kv = self.max_seqlen_kv - step_lens = inference_params.cu_seqlens_q[1:] - inference_params.cu_seqlens_q[:-1] + #step_lens = inference_params.cu_seqlens_q[1:] - inference_params.cu_seqlens_q[:-1] + step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] #seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] max_ctx_len=1 #output.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 max_seq_len=inference_params.max_ctx_len #q_buffer.shape[1] #64 #128 diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 1643b39c57..9dbbab12ca 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -148,9 +148,9 @@ template __global__ void copy_to_kv_cache_non_paged_kernel_q( scalar_t* new_k, scalar_t* k_cache, - int* batch_indices, + //int* batch_indices, int* step_lens, - int* seq_lens, + //int* seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_kv, int b, int max_ctx_len, int max_seq_len) { @@ -160,7 +160,7 @@ __global__ void copy_to_kv_cache_non_paged_kernel_q( for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int num_elts = step_lens[batch_idx] * h_kv * d_kv; int new_token_offset = batch_idx * max_ctx_len * h_kv * d_kv; - int cache_offset = batch_idx * max_seq_len * h_kv * d_kv + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; + int cache_offset = batch_idx * max_seq_len * h_kv * d_kv; // + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; scalar_t* new_k_token = new_k + new_token_offset; scalar_t* k_cache_token = k_cache + cache_offset; @@ -171,7 +171,7 @@ __global__ void copy_to_kv_cache_non_paged_kernel_q( } } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - 
step_lens[batch_idx]); + int cache_offset = batch_idx * max_seq_len; // + (seq_lens[batch_idx] - step_lens[batch_idx]); for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { for (int j = 0; j < h_kv * d_kv; j ++) { *(k_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_k + (i * b + batch_idx) * h_kv * d_kv +j); @@ -187,7 +187,7 @@ __global__ void copy_to_kv_cache_non_paged_kernel_q( new_token_offset += step_lens[t]; } new_token_offset = new_token_offset * h_kv * d_kv; - int cache_offset = (batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; + int cache_offset = batch_idx * max_seq_len * h_kv * d_kv; // + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; scalar_t* new_k_token = new_k + new_token_offset; scalar_t* k_cache_token = k_cache + cache_offset; @@ -262,6 +262,48 @@ __global__ void copy_to_kv_cache_non_paged_kernel_o( } } template +__global__ void copy_to_kv_cache_non_paged_kernel_reindex( + scalar_t* k_cache, scalar_t* v_cache, + int* batch_indices, + int* step_lens, + int* seq_lens, + NVTE_QKV_Format qkv_format, + int h_kv, int d_k, int d_v, + int b, int max_ctx_len, int max_seq_len) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // batch_indices, step_lens, seq_lens: [b + 1] + // only support bshd as cache format + //if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + int actual_b = 0; + for (int i = 1; i < b; i++) { + if (batch_indices[i] > batch_indices[i-1]) { + actual_b = i+1; + } + } + for (int batch_idx = 0; batch_idx < actual_b; batch_idx ++) { + for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; token_idx += gridDim.x) { + //if (blockIdx.x < 2 && threadIdx.x < 3) { + // printf("bid %d tid %d, b %d actual_b %d, len %d, + // output_offset %d output_buffer_offset %d, output_buffer_token + i %p output_token + i %p\n",h_q, d_q, + // batch_idx, token_idx, output_offset, output_buffer_offset, output_buffer_token + i, output_token + i); + //} + int 
num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; + int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; + int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; + int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; + for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { + *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); + } + for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { + *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); + } + } + } + //} +} +template void copy_to_kv_cache_non_paged_launcher( torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, @@ -273,6 +315,13 @@ void copy_to_kv_cache_non_paged_launcher( int b, int max_ctx_len, int max_seq_len) { if (new_v.data_ptr() != nullptr && d_k != d_v) { printf("-------- 1 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); + copy_to_kv_cache_non_paged_kernel_reindex<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + batch_indices.data_ptr(), + step_lens.data_ptr(), + seq_lens.data_ptr(), + qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); copy_to_kv_cache_non_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(new_v.data_ptr()), @@ -285,6 +334,13 @@ void copy_to_kv_cache_non_paged_launcher( } if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && d_k == d_v) { printf("-------- 2 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); + copy_to_kv_cache_non_paged_kernel_reindex<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + batch_indices.data_ptr(), + step_lens.data_ptr(), + 
seq_lens.data_ptr(), + qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); copy_to_kv_cache_non_paged_kernel_same_d<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(new_v.data_ptr()), @@ -300,9 +356,9 @@ void copy_to_kv_cache_non_paged_launcher( copy_to_kv_cache_non_paged_kernel_q<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(k_cache.data_ptr()), - batch_indices.data_ptr(), + //batch_indices.data_ptr(), step_lens.data_ptr(), - seq_lens.data_ptr(), + //seq_lens.data_ptr(), qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); } if (new_k.data_ptr() == nullptr) { diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 767d6dbc0f..3d028e534e 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -176,12 +176,20 @@ def step( The value cache tensor containing previous and the current tokens """ k_cache, v_cache = self.cache[layer_number] + #kk=k_cache.clone() + #k_cache1 = kk[self.batch_indices].contiguous() + #k_cache = k_cache[self.batch_indices].contiguous() + #v_cache = v_cache[self.batch_indices].contiguous() step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] #h=self.num_heads #16 #d=self.head_dim_k #64 #b=self.max_batch_size #4 - max_ctx_len=k.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + max_ctx_len=1 #k.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + if qkv_format == "bshd": + max_ctx_len=k.shape[1] + if qkv_format == "sbhd": + max_ctx_len=k.shape[0] max_seq_len=self.max_seqlen #k_cache.shape[1] #64 #128 max_ctx_tokens=k.shape[0] max_tokens=k_cache.shape[0]*k_cache.shape[1] @@ -195,6 +203,11 @@ def step( self.batch_indices, step_lens, seq_lens, QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, 
self.max_batch_size, max_ctx_len, max_seq_len)#, max_ctx_tokens, max_tokens) + #print(k_cache1[0, :2, 0, :4]) + #print(k_cache1[1, :2, 0, :4]) + #print(k_cache[0, :2, 0, :4]) + #print(k_cache[1, :2, 0, :4]) + self.cache[layer_number] = k_cache, v_cache return k_cache, v_cache, None # #prev_batch_size = len(self.sequences) From 737f45a3576fdf87eeda9a5ce5bd8a896cb6f086 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 14 Feb 2025 12:43:42 -0800 Subject: [PATCH 079/239] WIP: non-paged, using paged kernel Signed-off-by: Charlene Yang --- .../pytorch/csrc/extensions/attention.cu | 127 +++++++++++++++++- 1 file changed, 125 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9dbbab12ca..4357b40c3e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -13,6 +13,115 @@ using namespace transformer_engine::fused_attn; constexpr int block_size = 512; constexpr int ctas_per_sm = 4; +template +__global__ void copy_to_kv_cache_paged_kernel( + scalar_t* new_k, scalar_t* new_v, + scalar_t* k_cache, scalar_t* v_cache, + int* page_table, + int* step_lens, + int* seq_lens, + NVTE_QKV_Format qkv_format, + int h, int d, + int b, int max_ctx_len, int max_seq_len, + int max_ctx_tokens, int max_tokens, + int max_pages_per_seq) { + int page_size = max_seq_len / max_pages_per_seq; + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // batch_indices, step_lens, seq_lens: [b + 1] + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int* page_list = page_table + batch_idx * max_pages_per_seq; + int new_token_offset = batch_idx * max_ctx_len; + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; + int token_idx = 
(seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; + for (int j = 0; j < h * d; j ++) { + *(k_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_k + (new_token_offset + i) * h * d +j); + *(v_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_v + (new_token_offset + i) * h * d +j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int* page_list = page_table + batch_idx * max_pages_per_seq; + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; + int token_idx = (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; + for (int j = 0; j < h * d; j ++) { + *(k_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_k + (i * b + batch_idx) * h * d +j); + *(v_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_v + (i * b + batch_idx) * h * d +j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + // no padding between sequences in new_k and new_v + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int* page_list = page_table + batch_idx * max_pages_per_seq; + int new_token_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + new_token_offset += step_lens[t]; + } + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; + int token_idx = (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; + for (int j = 0; j < h * d; j ++) { + *(k_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_k + (new_token_offset + i) * h * d +j); + *(v_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_v + (new_token_offset + i) * h * d +j); + } + } + } + } +} +//template +//void copy_to_kv_cache_paged_launcher( +// torch::Tensor new_k, torch::Tensor new_v, 
+// torch::Tensor k_cache, torch::Tensor v_cache, +// torch::Tensor page_table, +// torch::Tensor step_lens, +// torch::Tensor seq_lens, +// NVTE_QKV_Format qkv_format, +// int h, int d, +// int b, int max_ctx_len, int max_seq_len, +// int max_ctx_tokens, int max_tokens, +// int max_pages_per_seq) { +// copy_to_kv_cache_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( +// reinterpret_cast(new_k.data_ptr()), +// reinterpret_cast(new_v.data_ptr()), +// reinterpret_cast(k_cache.data_ptr()), +// reinterpret_cast(v_cache.data_ptr()), +// page_table.data_ptr(), +// step_lens.data_ptr(), +// seq_lens.data_ptr(), +// qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); +//} +// +//void copy_to_kv_cache_paged( +// torch::Tensor new_k, torch::Tensor new_v, +// torch::Tensor k_cache, torch::Tensor v_cache, +// torch::Tensor page_table, +// torch::Tensor step_lens, +// torch::Tensor seq_lens, +// NVTE_QKV_Format qkv_format, +// int h, int d, +// int b, int max_ctx_len, int max_seq_len, +// int max_ctx_tokens, int max_tokens, +// int max_pages_per_seq) { +// if (k_cache.scalar_type() == at::ScalarType::Half) { +// using dtype = at::Half; +// copy_to_kv_cache_paged_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); +// +// } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { +// using dtype = at::BFloat16; +// copy_to_kv_cache_paged_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); +// } else if (k_cache.scalar_type() == at::ScalarType::Float) { +// using dtype = float; +// copy_to_kv_cache_paged_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); +// } else { +// 
NVTE_ERROR("Unsupported dtype.\n"); +// } +//} + + template __global__ void copy_to_kv_cache_non_paged_kernel( scalar_t* new_k, scalar_t* new_v, @@ -301,6 +410,11 @@ __global__ void copy_to_kv_cache_non_paged_kernel_reindex( } } } + if (blockIdx.x == 0) { + for (int batch_idx = threadIdx.x; batch_idx < actual_b; batch_idx ++) { + batch_indices[batch_idx] = batch_idx; + } + } //} } template @@ -341,7 +455,16 @@ void copy_to_kv_cache_non_paged_launcher( step_lens.data_ptr(), seq_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); - copy_to_kv_cache_non_paged_kernel_same_d<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( +// copy_to_kv_cache_non_paged_kernel_same_d<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( +// reinterpret_cast(new_k.data_ptr()), +// reinterpret_cast(new_v.data_ptr()), +// reinterpret_cast(k_cache.data_ptr()), +// reinterpret_cast(v_cache.data_ptr()), +// batch_indices.data_ptr(), +// step_lens.data_ptr(), +// seq_lens.data_ptr(), +// qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); + copy_to_kv_cache_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_k.data_ptr()), reinterpret_cast(new_v.data_ptr()), reinterpret_cast(k_cache.data_ptr()), @@ -349,7 +472,7 @@ void copy_to_kv_cache_non_paged_launcher( batch_indices.data_ptr(), step_lens.data_ptr(), seq_lens.data_ptr(), - qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); + qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len, 1, 1, 1); } if (new_v.data_ptr() == nullptr) { printf("-------- 3 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); From ac015efcef819dcd29ddbcbb45e7141aa92d7931 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 14 Feb 2025 14:27:53 -0800 Subject: [PATCH 080/239] WIP: restructure kernels Signed-off-by: Charlene Yang --- transformer_engine/pytorch/attention.py | 23 +- transformer_engine/pytorch/csrc/extensions.h | 20 +- .../pytorch/csrc/extensions/attention.cu | 661 +++++++----------- 
.../pytorch/csrc/extensions/pybind.cpp | 4 +- .../pytorch/kv_cache_manager_non_paged.py | 4 +- 5 files changed, 283 insertions(+), 429 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 088ef8db65..7194d45abf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1505,11 +1505,12 @@ def update_cache( print('q xxxxxxxxxxxx ',self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, max_ctx_len, max_seq_len)#, max_ctx_tokens, max_tokens) # TODO: batch_indices - tex.copy_to_kv_cache_non_paged( - q, self.q_dummy, q_buffer, self.q_dummy, - self.batch_indices, step_lens, step_lens, - QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, - max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) + tex.reshape_q(q, q_buffer, step_lens, QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.max_batch_size, max_ctx_len, max_seq_len) + #tex.copy_to_kv_cache_non_paged( + # q, self.q_dummy, q_buffer, self.q_dummy, + # self.batch_indices, step_lens, step_lens, + # QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, + # max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) #q = q_buffer #q_buffer = q_buffer_copy #torch.save(q_buffer, 'q_buffer.pt') @@ -8277,11 +8278,13 @@ def forward( #print('o xxxxxxxxxxxx ',step_lens, #self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, # max_ctx_len, max_seq_len, output.shape, output_buffer.shape)#, max_ctx_tokens, max_tokens) # TODO: batch_indices - tex.copy_to_kv_cache_non_paged( - inference_params.q_dummy, output, inference_params.q_dummy, output_buffer, - inference_params.batch_indices, step_lens, step_lens, - QKVFormat[qkv_format], inference_params.num_heads_q, inference_params.head_dim_q, inference_params.head_dim_q, inference_params.max_batch_size, - max_ctx_len, max_seq_len) #, max_ctx_tokens, 
max_tokens) + tex.reshape_o(output, output_buffer, step_lens, + inference_params.num_heads_q, inference_params.head_dim_q, inference_params.max_batch_size, max_seq_len) #, max_ctx_tokens, max_tokens) + #tex.copy_to_kv_cache_non_paged( + # inference_params.q_dummy, output, inference_params.q_dummy, output_buffer, + # inference_params.batch_indices, step_lens, step_lens, + # QKVFormat[qkv_format], inference_params.num_heads_q, inference_params.head_dim_q, inference_params.head_dim_q, inference_params.max_batch_size, + # max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) output = output_buffer.view(output_buffer.shape[0], -1) return output diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index ba2f8e4530..9dc35e0d5a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -34,15 +34,27 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T /*************************************************************************************************** * Attention **************************************************************************************************/ -void copy_to_kv_cache_non_paged( +void reshape_q( + torch::Tensor new_q, torch::Tensor q_buffer, + torch::Tensor step_lens, + NVTE_QKV_Format qkv_format, + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len); + +void reshape_o( + torch::Tensor output, torch::Tensor output_buffer, + torch::Tensor step_lens, + int h_o, int d_o, int b, int max_seq_len); + +void copy_to_kv_cache( torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, - torch::Tensor batch_indices, + torch::Tensor page_table, torch::Tensor step_lens, torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len); + int h_kv, int d_k, int d_v, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged); 
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 4357b40c3e..1399faa955 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -14,511 +14,348 @@ constexpr int block_size = 512; constexpr int ctas_per_sm = 4; template -__global__ void copy_to_kv_cache_paged_kernel( - scalar_t* new_k, scalar_t* new_v, - scalar_t* k_cache, scalar_t* v_cache, - int* page_table, +__global__ void reshape_q_kernel( + scalar_t* new_q, + scalar_t* q_buffer, int* step_lens, - int* seq_lens, NVTE_QKV_Format qkv_format, - int h, int d, - int b, int max_ctx_len, int max_seq_len, - int max_ctx_tokens, int max_tokens, - int max_pages_per_seq) { - int page_size = max_seq_len / max_pages_per_seq; - // new_k, new_v: qkv_format; k_cache, v_cache: bshd - // batch_indices, step_lens, seq_lens: [b + 1] + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + // new_q: qkv_format; q_buffer: bshd + // step_lens: [b + 1] if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int* page_list = page_table + batch_idx * max_pages_per_seq; - int new_token_offset = batch_idx * max_ctx_len; - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; - int token_idx = (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; - for (int j = 0; j < h * d; j ++) { - *(k_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_k + (new_token_offset + i) * h * d +j); - *(v_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_v + (new_token_offset + i) * h * d +j); - } + int num_elts = step_lens[batch_idx] * h_q * d_q; + int 
new_token_offset = batch_idx * max_ctx_len * h_q * d_q; + int cache_offset = batch_idx * max_seq_len * h_q * d_q; + scalar_t* new_q_token = new_q + new_token_offset; + scalar_t* q_buffer_token = q_buffer + cache_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(q_buffer_token + i) = *(new_q_token + i); } } } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int* page_list = page_table + batch_idx * max_pages_per_seq; + int cache_offset = batch_idx * max_seq_len; for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; - int token_idx = (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; - for (int j = 0; j < h * d; j ++) { - *(k_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_k + (i * b + batch_idx) * h * d +j); - *(v_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_v + (i * b + batch_idx) * h * d +j); + for (int j = 0; j < h_q * d_q; j ++) { + *(q_buffer + (cache_offset + i) * h_q * d_q + j) = *(new_q + (i * b + batch_idx) * h_q * d_q +j); } } } } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - // no padding between sequences in new_k and new_v for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int* page_list = page_table + batch_idx * max_pages_per_seq; + int num_elts = step_lens[batch_idx] * h_q * d_q; int new_token_offset = 0; for (int t = 0; t < batch_idx; t ++) { new_token_offset += step_lens[t]; } - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; - int token_idx = (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; - for (int j = 0; j < h * d; j ++) { - *(k_cache + (page_idx * page_size + token_idx) * h * d + j) = *(new_k + (new_token_offset + i) * h * d +j); - *(v_cache + 
(page_idx * page_size + token_idx) * h * d + j) = *(new_v + (new_token_offset + i) * h * d +j); - } + new_token_offset = new_token_offset * h_q * d_q; + int cache_offset = batch_idx * max_seq_len * h_q * d_q; + scalar_t* new_q_token = new_q + new_token_offset; + scalar_t* q_buffer_token = q_buffer + cache_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(q_buffer_token + i) = *(new_q_token + i); } } } } -//template -//void copy_to_kv_cache_paged_launcher( -// torch::Tensor new_k, torch::Tensor new_v, -// torch::Tensor k_cache, torch::Tensor v_cache, -// torch::Tensor page_table, -// torch::Tensor step_lens, -// torch::Tensor seq_lens, -// NVTE_QKV_Format qkv_format, -// int h, int d, -// int b, int max_ctx_len, int max_seq_len, -// int max_ctx_tokens, int max_tokens, -// int max_pages_per_seq) { -// copy_to_kv_cache_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( -// reinterpret_cast(new_k.data_ptr()), -// reinterpret_cast(new_v.data_ptr()), -// reinterpret_cast(k_cache.data_ptr()), -// reinterpret_cast(v_cache.data_ptr()), -// page_table.data_ptr(), -// step_lens.data_ptr(), -// seq_lens.data_ptr(), -// qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); -//} -// -//void copy_to_kv_cache_paged( -// torch::Tensor new_k, torch::Tensor new_v, -// torch::Tensor k_cache, torch::Tensor v_cache, -// torch::Tensor page_table, -// torch::Tensor step_lens, -// torch::Tensor seq_lens, -// NVTE_QKV_Format qkv_format, -// int h, int d, -// int b, int max_ctx_len, int max_seq_len, -// int max_ctx_tokens, int max_tokens, -// int max_pages_per_seq) { -// if (k_cache.scalar_type() == at::ScalarType::Half) { -// using dtype = at::Half; -// copy_to_kv_cache_paged_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); -// -// } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { 
-// using dtype = at::BFloat16; -// copy_to_kv_cache_paged_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); -// } else if (k_cache.scalar_type() == at::ScalarType::Float) { -// using dtype = float; -// copy_to_kv_cache_paged_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h, d, b, max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens, max_pages_per_seq); -// } else { -// NVTE_ERROR("Unsupported dtype.\n"); -// } -//} +template +void reshape_q_launcher( + torch::Tensor new_q, torch::Tensor q_buffer, + torch::Tensor step_lens, + NVTE_QKV_Format qkv_format, + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + printf("-------- 3 %p %d %d %d \n"); //, new_v.data_ptr(), h_kv, d_k, d_v); + reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_q.data_ptr()), + reinterpret_cast(q_buffer.data_ptr()), + step_lens.data_ptr(), + qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); +} + +void reshape_q( + torch::Tensor new_q, torch::Tensor q_buffer, + torch::Tensor step_lens, + NVTE_QKV_Format qkv_format, + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), + "new_q and q_buffer must be of the same data type."); + NVTE_CHECK( + qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_THD, + "qkv_format must be {BSHD, SBHD, THD}."); + if (q_buffer.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + } else if 
(q_buffer.scalar_type() == at::ScalarType::Float) { + using dtype = float; + reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); +// } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { +// using dtype = at::kFloat8_e4m3fn; +// reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); +// } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { +// using dtype = at::kFloat8_e5m2; +// reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } +} template -__global__ void copy_to_kv_cache_non_paged_kernel( - scalar_t* new_k, scalar_t* new_v, - scalar_t* k_cache, scalar_t* v_cache, - int* batch_indices, +__global__ void reshape_o_kernel( + scalar_t* output, + scalar_t* output_buffer, int* step_lens, - int* seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len) { - // new_k, new_v: qkv_format; k_cache, v_cache: bshd - // batch_indices, step_lens, seq_lens: [b + 1] - if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts_k = step_lens[batch_idx] * h_kv * d_k; - int num_elts_v = step_lens[batch_idx] * h_kv * d_v; - int new_token_offset_k = batch_idx * max_ctx_len * h_kv * d_k; - int new_token_offset_v = batch_idx * max_ctx_len * h_kv * d_v; - int cache_offset_k = batch_idx * max_seq_len * h_kv * d_k + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_k; - int cache_offset_v = batch_idx * max_seq_len * h_kv * d_v + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_v; - - scalar_t* new_k_token = new_k + new_token_offset_k; - scalar_t* k_cache_token = k_cache + cache_offset_k; - scalar_t* new_v_token = new_v + new_token_offset_v; - scalar_t* v_cache_token = v_cache + cache_offset_v; - - for (int i = 
threadIdx.x; i < num_elts_k; i += blockDim.x) { - *(k_cache_token + i) = *(new_k_token + i); - } - for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { - *(v_cache_token + i) = *(new_v_token + i); - } + int h_o, int d_o, + int b, int max_seq_len) { + // output: bshd; output_buffer: thd; + // step_lens: [b + 1] + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = step_lens[batch_idx] * h_o * d_o; + int output_offset = batch_idx * max_seq_len * h_o * d_o; + int output_buffer_offset = 0; + for (int t = 0; t < batch_idx; t ++) { + output_buffer_offset += step_lens[t]; } - } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - step_lens[batch_idx]); - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h_kv * d_k; j ++) { - *(k_cache + (cache_offset + i) * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k +j); - } - for (int j = 0; j < h_kv * d_v; j ++) { - *(v_cache + (cache_offset + i) * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v +j); - } - } + output_buffer_offset = output_buffer_offset * h_o * d_o; + scalar_t* output_token = output + output_offset; + scalar_t* output_buffer_token = output_buffer + output_buffer_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(output_buffer_token + i) = *(output_token + i); } - } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - // no padding between sequences in new_k and new_v - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts_k = step_lens[batch_idx] * h_kv * d_k; - int num_elts_v = step_lens[batch_idx] * h_kv * d_v; - int new_token_offset = 0; - for (int t = 0; t < batch_idx; t ++) { - new_token_offset += step_lens[t]; - } - int cache_offset = batch_idx * max_seq_len + seq_lens[batch_idx] - 
step_lens[batch_idx]; + } +} - scalar_t* new_k_token = new_k + new_token_offset * h_kv * d_k; - scalar_t* k_cache_token = k_cache + cache_offset * h_kv * d_k; - scalar_t* new_v_token = new_v + new_token_offset * h_kv * d_v; - scalar_t* v_cache_token = v_cache + cache_offset * h_kv * d_v; +template +void reshape_o_launcher( + torch::Tensor output, torch::Tensor output_buffer, + torch::Tensor step_lens, + int h_o, int d_o, int b, int max_seq_len) { + printf("-------- 4 %p %d %d %d \n"); //, new_v.data_ptr(), h_kv, d_k, d_v); + reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_buffer.data_ptr()), + step_lens.data_ptr(), + h_o, d_o, b, max_seq_len); +} - for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { - *(k_cache_token + i) = *(new_k_token + i); - } - for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { - *(v_cache_token + i) = *(new_v_token + i); - } - } +void reshape_o( + torch::Tensor output, torch::Tensor output_buffer, + torch::Tensor step_lens, + int h_o, int d_o, int b, int max_seq_len) { + NVTE_CHECK( + output.scalar_type() == output_buffer.scalar_type(), + "output and output_buffer must be of the same data type."); + if (output.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); + } else if (output.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); + } else if (output.scalar_type() == at::ScalarType::Float) { + using dtype = float; + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); +// } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { +// using dtype = at::kFloat8_e4m3fn; +// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); +// } else if (output.scalar_type() == 
at::ScalarType::Float8_e5m2) { +// using dtype = at::kFloat8_e5m2; +// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); } } + template -__global__ void copy_to_kv_cache_non_paged_kernel_same_d( - scalar_t* new_k, scalar_t* new_v, +__global__ void reindex_kv_cache_kernel( scalar_t* k_cache, scalar_t* v_cache, int* batch_indices, int* step_lens, int* seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_kv, - int b, int max_ctx_len, int max_seq_len) { - // new_k, new_v: qkv_format; k_cache, v_cache: bshd + int h_kv, int d_k, int d_v, int b, + int max_seq_len) { + // k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] - if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_kv * d_kv; - int new_token_offset = batch_idx * max_ctx_len * h_kv * d_kv; - int cache_offset = batch_idx * max_seq_len * h_kv * d_kv + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; - - scalar_t* new_k_token = new_k + new_token_offset; - scalar_t* k_cache_token = k_cache + cache_offset; - scalar_t* new_v_token = new_v + new_token_offset; - scalar_t* v_cache_token = v_cache + cache_offset; - - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(k_cache_token + i) = *(new_k_token + i); - *(v_cache_token + i) = *(new_v_token + i); - } - } - } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int cache_offset = batch_idx * max_seq_len + (seq_lens[batch_idx] - step_lens[batch_idx]); - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h_kv * d_kv; j ++) { - *(k_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_k + (i * b + batch_idx) * h_kv * d_kv +j); - *(v_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_v + (i * b + 
batch_idx) * h_kv * d_kv +j); - } - } + int actual_b = 0; + for (int i = 1; i < b; i++) { + if (batch_indices[i] > batch_indices[i-1]) { + actual_b = i+1; } - } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - // no padding between sequences in new_k and new_v - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_kv * d_kv; - int new_token_offset = 0; - for (int t = 0; t < batch_idx; t ++) { - new_token_offset += step_lens[t]; + } + for (int batch_idx = 0; batch_idx < actual_b; batch_idx ++) { + for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; token_idx += gridDim.x) { + int num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; + int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; + int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; + int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; + for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { + *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); } - new_token_offset = new_token_offset * h_kv * d_kv; - int cache_offset = (batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; - - scalar_t* new_k_token = new_k + new_token_offset; - scalar_t* k_cache_token = k_cache + cache_offset; - scalar_t* new_v_token = new_v + new_token_offset; - scalar_t* v_cache_token = v_cache + cache_offset; - - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(k_cache_token + i) = *(new_k_token + i); - *(v_cache_token + i) = *(new_v_token + i); + for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { + *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); } } } + if (blockIdx.x == 0) { + for (int batch_idx = threadIdx.x; batch_idx < actual_b; batch_idx 
++) { + batch_indices[batch_idx] = batch_idx; + } + } } + template -__global__ void copy_to_kv_cache_non_paged_kernel_q( - scalar_t* new_k, - scalar_t* k_cache, - //int* batch_indices, +__global__ void copy_to_kv_cache_kernel( + scalar_t* new_k, scalar_t* new_v, + scalar_t* k_cache, scalar_t* v_cache, + int* page_table, int* step_lens, - //int* seq_lens, + int* seq_lens, NVTE_QKV_Format qkv_format, - int h_kv, int d_kv, - int b, int max_ctx_len, int max_seq_len) { - // new_k, new_v: qkv_format; k_cache, v_cache: bshd - // batch_indices, step_lens, seq_lens: [b + 1] + int h_kv, int d_k, int d_v, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq) { + int page_size = max_seq_len / max_pages_per_seq; if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_kv * d_kv; - int new_token_offset = batch_idx * max_ctx_len * h_kv * d_kv; - int cache_offset = batch_idx * max_seq_len * h_kv * d_kv; // + (seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; - - scalar_t* new_k_token = new_k + new_token_offset; - scalar_t* k_cache_token = k_cache + cache_offset; - - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(k_cache_token + i) = *(new_k_token + i); + int* page_list = page_table + batch_idx * max_pages_per_seq; + int new_token_offset = batch_idx * max_ctx_len; + for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; + int token_idx = page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; + for (int j = 0; j < h_kv * d_k; j ++) { + *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (new_token_offset + i) * h_kv * d_k +j); + } + for (int j = 0; j < h_kv * d_v; j ++) { + *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (new_token_offset + i) * h_kv * d_v +j); + } } } } else if (qkv_format == 
NVTE_QKV_Format::NVTE_SBHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int cache_offset = batch_idx * max_seq_len; // + (seq_lens[batch_idx] - step_lens[batch_idx]); + int* page_list = page_table + batch_idx * max_pages_per_seq; for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h_kv * d_kv; j ++) { - *(k_cache + (cache_offset + i) * h_kv * d_kv + j) = *(new_k + (i * b + batch_idx) * h_kv * d_kv +j); + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; + int token_idx = page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; + for (int j = 0; j < h_kv * d_k; j ++) { + *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k +j); + } + for (int j = 0; j < h_kv * d_v; j ++) { + *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v +j); } } } } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - // no padding between sequences in new_k and new_v for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_kv * d_kv; + int* page_list = page_table + batch_idx * max_pages_per_seq; int new_token_offset = 0; for (int t = 0; t < batch_idx; t ++) { new_token_offset += step_lens[t]; } - new_token_offset = new_token_offset * h_kv * d_kv; - int cache_offset = batch_idx * max_seq_len * h_kv * d_kv; // + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; - - scalar_t* new_k_token = new_k + new_token_offset; - scalar_t* k_cache_token = k_cache + cache_offset; - - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(k_cache_token + i) = *(new_k_token + i); - } - } - } -} -template -__global__ void copy_to_kv_cache_non_paged_kernel_o( - scalar_t* output, // output in bshd - scalar_t* output_buffer, // output in thd - //int* batch_indices, - int* step_lens, - //int* seq_lens, - NVTE_QKV_Format qkv_format, - int h_q, 
int d_q, - int b, int max_seq_len) { - // new_k, new_v: qkv_format; k_cache, v_cache: bshd - // batch_indices, step_lens, seq_lens: [b + 1] - if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_q * d_q; - int output_offset = batch_idx * max_seq_len * h_q * d_q; - int output_buffer_offset = 0; - for (int t = 0; t < batch_idx; t ++) { - output_buffer_offset += step_lens[t]; - } - output_buffer_offset = output_buffer_offset * h_q * d_q; - scalar_t* output_token = output + output_offset; - scalar_t* output_buffer_token = output_buffer + output_buffer_offset; - - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - if (batch_idx < 2 && i < 3) { - printf("h_q %d d_q %d, b %d t %d, output_offset %d output_buffer_offset %d, output_buffer_token + i %p output_token + i %p\n",h_q, d_q,batch_idx, i, output_offset, output_buffer_offset, output_buffer_token + i, output_token + i); - } - *(output_buffer_token + i) = *(output_token + i); - } - } - } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int output_buffer_offset = 0; - for (int t = 0; t < batch_idx; t ++) { - output_buffer_offset += step_lens[t]; - } for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h_q * d_q; j ++) { - *(output_buffer + (output_buffer_offset + i) * h_q * d_q + j) = *(output + (i * b + batch_idx) * h_q * d_q +j); + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; + int token_idx = page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; + for (int j = 0; j < h_kv * d_k; j ++) { + *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (new_token_offset + i) * h_kv * d_k +j); + } + for (int j = 0; j < h_kv * d_v; j ++) { + *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (new_token_offset + i) * h_kv * 
d_v +j); } } } -// } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { -// // no padding between sequences in new_k and new_v -// for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { -// int num_elts = step_lens[batch_idx] * h_kv * d_kv; -// int new_token_offset = 0; -// for (int t = 0; t < batch_idx; t ++) { -// new_token_offset += step_lens[t]; -// } -// new_token_offset = new_token_offset * h_kv * d_kv; -// int cache_offset = (batch_idx * max_seq_len + seq_lens[batch_idx] - step_lens[batch_idx]) * h_kv * d_kv; -// -// scalar_t* new_k_token = new_k + new_token_offset; -// scalar_t* k_cache_token = k_cache + cache_offset; -// -// for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { -// *(k_cache_token + i) = *(new_k_token + i); -// } -// } } } + template -__global__ void copy_to_kv_cache_non_paged_kernel_reindex( - scalar_t* k_cache, scalar_t* v_cache, - int* batch_indices, - int* step_lens, - int* seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len) { - // new_k, new_v: qkv_format; k_cache, v_cache: bshd - // batch_indices, step_lens, seq_lens: [b + 1] - // only support bshd as cache format - //if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { - int actual_b = 0; - for (int i = 1; i < b; i++) { - if (batch_indices[i] > batch_indices[i-1]) { - actual_b = i+1; - } - } - for (int batch_idx = 0; batch_idx < actual_b; batch_idx ++) { - for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; token_idx += gridDim.x) { - //if (blockIdx.x < 2 && threadIdx.x < 3) { - // printf("bid %d tid %d, b %d actual_b %d, len %d, - // output_offset %d output_buffer_offset %d, output_buffer_token + i %p output_token + i %p\n",h_q, d_q, - // batch_idx, token_idx, output_offset, output_buffer_offset, output_buffer_token + i, output_token + i); - //} - int num_elts_k = h_kv * d_k; - int num_elts_v = h_kv * d_v; - int k_cache_src_offset = (batch_indices[batch_idx] * 
max_seq_len + token_idx) * h_kv * d_k; - int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; - int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; - int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; - for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { - *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); - } - for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { - *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); - } - } - } - if (blockIdx.x == 0) { - for (int batch_idx = threadIdx.x; batch_idx < actual_b; batch_idx ++) { - batch_indices[batch_idx] = batch_idx; - } - } - //} -} -template -void copy_to_kv_cache_non_paged_launcher( +void copy_to_kv_cache_launcher( torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, - torch::Tensor batch_indices, + torch::Tensor page_table, torch::Tensor step_lens, torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len) { - if (new_v.data_ptr() != nullptr && d_k != d_v) { - printf("-------- 1 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); - copy_to_kv_cache_non_paged_kernel_reindex<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - batch_indices.data_ptr(), - step_lens.data_ptr(), - seq_lens.data_ptr(), - qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); - copy_to_kv_cache_non_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_k.data_ptr()), - reinterpret_cast(new_v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - batch_indices.data_ptr(), - step_lens.data_ptr(), - seq_lens.data_ptr(), - qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); - } - if (new_k.data_ptr() != nullptr && new_v.data_ptr() 
!= nullptr && d_k == d_v) { - printf("-------- 2 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); - copy_to_kv_cache_non_paged_kernel_reindex<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - batch_indices.data_ptr(), - step_lens.data_ptr(), - seq_lens.data_ptr(), - qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); -// copy_to_kv_cache_non_paged_kernel_same_d<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( -// reinterpret_cast(new_k.data_ptr()), -// reinterpret_cast(new_v.data_ptr()), -// reinterpret_cast(k_cache.data_ptr()), -// reinterpret_cast(v_cache.data_ptr()), -// batch_indices.data_ptr(), -// step_lens.data_ptr(), -// seq_lens.data_ptr(), -// qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); - copy_to_kv_cache_paged_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_k.data_ptr()), - reinterpret_cast(new_v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - batch_indices.data_ptr(), - step_lens.data_ptr(), - seq_lens.data_ptr(), - qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len, 1, 1, 1); - } - if (new_v.data_ptr() == nullptr) { - printf("-------- 3 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); - copy_to_kv_cache_non_paged_kernel_q<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_k.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - //batch_indices.data_ptr(), - step_lens.data_ptr(), - //seq_lens.data_ptr(), - qkv_format, h_kv, d_k, b, max_ctx_len, max_seq_len); - } - if (new_k.data_ptr() == nullptr) { - printf("-------- 4 %p %d %d %d \n", new_v.data_ptr(), h_kv, d_k, d_v); - copy_to_kv_cache_non_paged_kernel_o<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_v.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - //batch_indices.data_ptr(), - step_lens.data_ptr(), - //seq_lens.data_ptr(), - qkv_format, h_kv, d_k, b, 
max_seq_len); + int h_kv, int d_k, int d_v, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged) { + // 1. new_k, new_v: qkv_format; k_cache, v_cache: bshd + // 2. step_lens, seq_lens (step lens included): [b + 1] + // 3. non-paged cache can be considered a special case of paged cache, + // where page_table = [b, 1] and max_pages_per_seq = 1 + // 4. is_non_paged = True forces re-indexing of the cache based on page_table, + // i.e. page_table = [0, 3, 1, 2] will be rearranged to [0, 1, 1, 2] + // 5. assumes k_cache and v_cache have the same page_table + // 6. for THD, assumes no padding between sequences in new_k and new_v + if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && + k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) { + printf("-------- 1 %p %d %d %d \n"); //, new_v.data_ptr(), h_kv, d_k, d_v); + if (is_non_paged) { + reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + page_table.data_ptr(), + step_lens.data_ptr(), + seq_lens.data_ptr(), + h_kv, d_k, d_v, b, max_seq_len); + } + copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + page_table.data_ptr(), + step_lens.data_ptr(), + seq_lens.data_ptr(), + qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq); } } -void copy_to_kv_cache_non_paged( +// copy new K/V tokens to KV cache +void copy_to_kv_cache( torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, - torch::Tensor batch_indices, + torch::Tensor page_table, torch::Tensor step_lens, torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, - int b, int max_ctx_len, int max_seq_len) { + int h_kv, int d_k, int d_v, int b, + int 
max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged) { + NVTE_CHECK( + k_cache.scalar_type() == v_cache.scalar_type() && + new_k.scalar_type() == new_v.scalar_type() && + new_k.scalar_type() == k_cache.scalar_type(), + "new_k, new_v, k_cache and v_cache must be of the same data type."); + NVTE_CHECK( + qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_THD, + "qkv_format must be {BSHD, SBHD, THD}."); if (k_cache.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); } else if (k_cache.scalar_type() == at::ScalarType::Float) { using dtype = float; - copy_to_kv_cache_non_paged_launcher(new_k, new_v, k_cache, v_cache, batch_indices, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len); + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); +// } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) { +// using dtype = at::kFloat8_e4m3fn; +// copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, 
seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); +// } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) { +// using dtype = at::kFloat8_e5m2; +// copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); } else { - NVTE_ERROR("Unsupported dtype.\n"); + NVTE_ERROR("Unsupported dtype for KV cache.\n"); } } + // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 33ecd3a7cd..ba4f00e77f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -171,7 +171,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); - m.def("copy_to_kv_cache_non_paged", ©_to_kv_cache_non_paged, "Copy KV to non-paged KV cache"); + m.def("copy_to_kv_cache", ©_to_kv_cache, "Copy new KV tokens to KV cache"); + m.def("reshape_q", &reshape_q, "Reshape Q for THD before attention"); + m.def("reshape_o", &reshape_o, "Reshape O for THD after attention"); m.def("fused_attn_fwd", &fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &fused_attn_bwd, diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 3d028e534e..d05df0dd3d 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -198,11 +198,11 @@ def step( #print('seq_lens ', seq_lens) 
#print('self.batch_indices ', self.batch_indices) print('lensss ', max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) - tex.copy_to_kv_cache_non_paged( + tex.copy_to_kv_cache( k, v, k_cache, v_cache, self.batch_indices, step_lens, seq_lens, QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, self.max_batch_size, - max_ctx_len, max_seq_len)#, max_ctx_tokens, max_tokens) + max_ctx_len, max_seq_len, 1, True)#, max_ctx_tokens, max_tokens) #print(k_cache1[0, :2, 0, :4]) #print(k_cache1[1, :2, 0, :4]) #print(k_cache[0, :2, 0, :4]) From 45e9d8b6d82bc28fe58574dd647aee71db438e20 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 14 Feb 2025 15:35:17 -0800 Subject: [PATCH 081/239] [JAX] Lint Fix (#1484) JAX Lint Fix Signed-off-by: Phuong Nguyen --- transformer_engine/jax/flax/module.py | 1 - transformer_engine/jax/flax/transformer.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 2190c6df6c..23bc8d3602 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1251,7 +1251,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): # Remove act axis z = jnp.reshape(z, (*z.shape[:-2], -1)) z = z.astype(self.dtype) - # import pdb; pdb.set_trace() z = nn.Dropout( rate=self.intermediate_dropout_rate, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fbae73f131..49491a5cda 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -987,9 +987,8 @@ def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", self.weight_dtype + 1.0, "fan_in", "normal", dtype=self.weight_dtype ) - self.kernel_init = _kernel_init.astype(self.dtype) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads 
super().__post_init__() From dfbf4ddecaab3c09a933c5bcb64048281b5fd7bf Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Sat, 15 Feb 2025 07:35:54 +0800 Subject: [PATCH 082/239] [JAX] Fix issues when mask/sequence_descriptor is None (#1477) Fix issues when mask/sequence_descriptor is None Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen --- tests/jax/test_fused_attn.py | 17 ++++++++++------- transformer_engine/jax/attention.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index beaf18cea3..ff4139ee51 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -556,13 +556,16 @@ def generate_random_segment_ids( else: match self.seq_desc_format: case SeqDescFormat.Mask: - self.sequence_desciptor = make_mask( - self.segment_ids_q, - self.segment_ids_kv, - self.segment_pos_q, - self.segment_pos_kv, - self.attn_mask_type, - ) + if self.attn_mask_type == AttnMaskType.NO_MASK: + self.sequence_desciptor = None + else: + self.sequence_desciptor = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) case SeqDescFormat.Seqlens: self.sequence_desciptor = SequenceDescriptor.from_seqlens( ( diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 09128b013b..a8245b533e 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -950,7 +950,7 @@ def fused_attn( AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, QKVLayout.T3HD, 0.125, 0, True, 3) """ - if isinstance(sequence_descriptor, jnp.ndarray): + if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray): warnings.warn( "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. 
" + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.", From af7b2b44dd6173c9b3049f306c0773a938feceae Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Sat, 15 Feb 2025 07:36:34 +0800 Subject: [PATCH 083/239] [JAX] Expose THD format to the flax module (#1480) * Expose THD to flex MHA module Signed-off-by: Reese Wang * Enhance docs Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen --- .../jax/cpp_extensions/attention.py | 7 +- transformer_engine/jax/flax/transformer.py | 71 ++++++++++++++----- 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 1c32ef4cba..51ff87ced1 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -645,6 +645,8 @@ def partition(config, mesh, arg_infos, result_infos): ) arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[4] = seed_sharding + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) impl = partial(FusedAttnFwdPrimitive.impl, config=config) @@ -1042,7 +1044,10 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) def sharded_impl( diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 
49491a5cda..100557404b 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -24,7 +24,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import LayerNorm, Softmax -from ..attention import AttnBiasType, AttnMaskType, QKVLayout +from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type from ..attention import fused_attn from ..softmax import SoftmaxType @@ -267,6 +267,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me scale_factor: Optional[float] = None transpose_batch_sequence: bool = False window_size: Optional[Tuple[int, int]] = None + max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" @@ -276,7 +277,7 @@ def __call__( query: Array, key: Array, value: Array, - mask: Optional[Array] = None, + sequence_descriptor: Optional[SequenceDescriptor] = None, bias: Optional[Array] = None, *, dropout_rng: Optional[PRNGKey] = None, @@ -293,8 +294,7 @@ def __call__( scale_factor = self.scale_factor del self.scale_factor - # TODO(rewang): integrate THD format - if self.qkv_layout == QKVLayout.BS3HD: + if self.qkv_layout.is_qkvpacked(): """qkvpacked format, treat query: qkvpacked tensor, shape = [..., 3, h, d] key: ignore @@ -306,7 +306,7 @@ def __call__( x = fused_attn( (qkv_packed,), bias, - mask, + sequence_descriptor, seed, attn_mask_type=self.attn_mask_type, attn_bias_type=self.attn_bias_type, @@ -315,10 +315,11 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, ) - elif self.qkv_layout == 
QKVLayout.BSHD_BS2HD: + elif self.qkv_layout.is_kvpacked(): """kvpacked format, treat query: query tensor, shape = [..., h, d] key: kvpacked tensor, shape = [..., 2, h, d] @@ -331,7 +332,7 @@ def __call__( x = fused_attn( (query, kv_packed), bias, - mask, + sequence_descriptor, seed, attn_mask_type=self.attn_mask_type, attn_bias_type=self.attn_bias_type, @@ -340,10 +341,11 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, ) - elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: + elif self.qkv_layout.is_separate(): if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) key = key.transpose([1, 0, 2, 3]) @@ -351,7 +353,7 @@ def __call__( x = fused_attn( (query, key, value), bias, - mask, + sequence_descriptor, seed, attn_mask_type=self.attn_mask_type, attn_bias_type=self.attn_bias_type, @@ -360,6 +362,7 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, ) @@ -437,6 +440,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + .. note:: THD format only supports 'padding' or 'causal_padding' mask type. + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. 
@@ -451,13 +456,15 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods qkv_layout: str, default = 'bshd_bshd_bshd' Specifies the dimensional layout format for the query, key, and value tensors in __call__(). It indicates how the inputs are processed. - Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where + Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}. * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d]. key and value arguments in :attr:`__call__()` are ignored in this layout. * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treaded as a kvpacked tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored. * bshd_bshd_bshd: query, key, and value are seperated with shape = [b, s, h, d]. + * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple + sequences to be packed in a batch, also known as sequence packing. Explanation of denotations: @@ -476,6 +483,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...). window_size: Optional[Tuple[int, int]], default = None Sliding window size. The default value is no sliding window. + max_segments_per_seq: Optional[int], default = 1 + The maximum number of segments per sequence, also used for THD format (sequence packing). context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. 
@@ -502,6 +511,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods scale_factor: Optional[float] = None transpose_batch_sequence: bool = True window_size: Optional[Tuple[int, int]] = None + max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" @@ -511,10 +521,11 @@ def __call__( query: Array, key: Array, value: Array, - mask: Optional[Array] = None, + sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None, bias: Optional[Array] = None, *, deterministic: bool = False, + mask: Optional[Union[SequenceDescriptor, Array]] = None, ) -> Array: """ Parameters @@ -542,6 +553,15 @@ def __call__( Output tensors. """ + if mask is not None: + if sequence_descriptor is not None: + raise ValueError( + "sequence_descriptor and mask cannot be provided at the same time." + ) + warnings.warn("mask is deprecated, please use sequence_descriptor instead.") + sequence_descriptor = mask + del mask + # For internal API, we use enum to maintain if self.attn_bias_type is None: attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS @@ -604,16 +624,18 @@ def __call__( if not use_fused_attn: # unfused attention only supports splitted query, key, value - if qkv_layout == QKVLayout.BS3HD: + if qkv_layout.is_qkvpacked(): query, key, value = jnp.split(query, [1, 2], axis=-3) query, key, value = map( functools.partial(jnp.squeeze, axis=-3), [query, key, value] ) - elif qkv_layout == QKVLayout.BSHD_BS2HD: + elif qkv_layout.is_kvpacked(): key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert qkv_layout.is_separate() + + assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray) x = _UnfusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -625,7 +647,15 @@ def __call__( 
scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, window_size=self.window_size, - )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) + )( + query, + key, + value, + sequence_descriptor, + bias, + dropout_rng=dropout_rng, + deterministic=deterministic, + ) else: x = _FusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -637,9 +667,18 @@ def __call__( transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, - )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) + )( + query, + key, + value, + sequence_descriptor, + bias, + dropout_rng=dropout_rng, + deterministic=deterministic, + ) return x From b39397c541292f336c5964dd1661d80c08dc4c78 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 14 Feb 2025 17:11:29 -0800 Subject: [PATCH 084/239] Changed VERSION to 2.2.0.dev0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index eb5820cd2d..6b959d99e8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.1.0.dev0 +2.2.0.dev0 From e52868bd24a44b16df67a5208791affe722f5019 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 14 Feb 2025 21:06:30 -0800 Subject: [PATCH 085/239] WIP: paged, CG Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 15 +- transformer_engine/pytorch/attention.py | 4 +- .../pytorch/csrc/extensions/attention.cu | 11 +- .../pytorch/kv_cache_manager_non_paged.py | 15 +- .../pytorch/kv_cache_manager_paged.py | 211 +++++++++++------- 5 files changed, 170 insertions(+), 86 deletions(-) diff --git 
a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 3b740392f7..76c1feb7f1 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -95,8 +95,8 @@ def __init__( self.poisson_rate = torch.randint(1, max_batch_size, [1]).item() interval_dist = Exponential(self.poisson_rate) arrival_intervals = interval_dist.sample((total_requests,)) - #self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") - self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") + self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") + #self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() # initialize tensors @@ -208,7 +208,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("dtype", [torch.float16])#param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) -@pytest.mark.parametrize("is_paged", [False])#, True]) +@pytest.mark.parametrize("is_paged", [False, True]) @pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): @@ -276,6 +276,15 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): dtype=dtype, device="cuda", ) + #print('k_full[0, 0, 0, :4]', k[0, 0, 0, :4]) + print('k_full[7, 46:48, 0, :4]', k[7, 46:48, 0, :4]) + #print('k_full[1, :2, 0, :4]', k[1, :2, 0, :4]) + #print('k_full[1, 6, 0, :4]', k[1, 6, 0, :4]) + #print('k_full[0, 17, 0, :4]', k[0, 17, 0, :4]) + #print('k_full[2, 22, 0, :4]', k[2, 22, 0, :4]) + #print('k_full[5, 14, 0, :4]', k[5, 14, 0, :4]) + #print('k_full[6, 12, 0, :4]', k[6, 12, 0, :4]) + # 
generate reference results logger.info("=== Generating all tokens at once ===") diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7194d45abf..ad316540fc 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7973,7 +7973,7 @@ def forward( ) # convert qkv layout to its corresponding paged attention layout if inference_params is not None and inference_params.is_paged: - qkv_layout = "paged_kv_" + qkv_format + "_2" + inference_params.qkv_format + qkv_layout = "paged_kv_" + qkv_format + "_2" + inference_params.cache_qkv_format global _alibi_cache if alibi_slopes is not None: @@ -8252,6 +8252,8 @@ def forward( #ooo = output.view(output.shape[:2], -1) #print('output ', output[0,0,:4]) #print('output ', output[1,0,:4]) + #print('output ', output[0,0,:4]) + #print('output ', output[1,6,:4]) if inference_params.input_qkv_format == "bshd": output = output[:batch_size, :max_seqlen_q].contiguous() if inference_params.input_qkv_format == "sbhd": diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 1399faa955..f168eb4178 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -181,12 +181,15 @@ __global__ void reindex_kv_cache_kernel( int max_seq_len) { // k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] - int actual_b = 0; - for (int i = 1; i < b; i++) { - if (batch_indices[i] > batch_indices[i-1]) { + int actual_b = b; + for (int i = 0; i < b-1; i++) { + if (batch_indices[i+1] < batch_indices[i]) { actual_b = i+1; } } + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("actual_b is %d\n", actual_b); + } for (int batch_idx = 0; batch_idx < actual_b; batch_idx ++) { for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; token_idx += gridDim.x) { int num_elts_k = h_kv * 
d_k; @@ -204,7 +207,7 @@ __global__ void reindex_kv_cache_kernel( } } if (blockIdx.x == 0) { - for (int batch_idx = threadIdx.x; batch_idx < actual_b; batch_idx ++) { + for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx ++) { batch_indices[batch_idx] = batch_idx; } } diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index d05df0dd3d..a750518edc 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -110,7 +110,7 @@ def prepare( ): # TODO: remove self.sequences = sequences - #self.step_dict = step_dict + self.step_dict = step_dict prev_batch_size = len(self.sequences) batch_size = len(step_dict) @@ -197,12 +197,21 @@ def step( #print('step_lens ', step_lens) #print('seq_lens ', seq_lens) #print('self.batch_indices ', self.batch_indices) - print('lensss ', max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + #print('lensss ', qkv_format, step_lens, seq_lens,max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) tex.copy_to_kv_cache( k, v, k_cache, v_cache, self.batch_indices, step_lens, seq_lens, - QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, self.max_batch_size, + QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, len(self.step_dict), #self.max_batch_size, max_ctx_len, max_seq_len, 1, True)#, max_ctx_tokens, max_tokens) + #print('self.batch_indices after', self.batch_indices) + #print('k_cache[0, 0, 0, :4]', k_cache[0, 0, 0, :4]) + #print('k_cache[0, 46:48, 0, :4]', k_cache[0, 46:48, 0, :4]) + #print('k_cache[1, :2, 0, :4]', k_cache[1, :2, 0, :4]) + #print('k_cache[1, 6, 0, :4]', k_cache[1, 6, 0, :4]) + #print('k_cache[0, 17, 0, :4]', k_cache[0, 17, 0, :4]) + #print('k_cache[1, 22, 0, :4]', k_cache[1, 22, 0, :4]) + #print('k_cache[2, 14, 0, :4]', k_cache[2, 14, 0, :4]) + #print('k_cache[3, 12, 0, :4]', k_cache[3, 12, 0, :4]) #print(k_cache1[0, :2, 0, 
:4]) #print(k_cache1[1, :2, 0, :4]) #print(k_cache[0, :2, 0, :4]) diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index e26ca22d5f..95d9b0c02c 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -4,11 +4,13 @@ """Paged KV Cache Manager.""" from collections import defaultdict, OrderedDict -from typing import List, Optional +from typing import List, Optional, Dict import logging import torch +import transformer_engine_torch as tex from transformer_engine.pytorch.kv_cache_manager_non_paged import KVCacheManager +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat class Page: @@ -45,7 +47,7 @@ def __init__( max_batch_size: int, max_seqlen: int, head_dim_v: Optional[int] = None, - is_cuda_graph: bool = False, + #is_cuda_graph: bool = False, ): """Initialize the KV cache""" self.total_num_pages = total_num_pages @@ -57,7 +59,7 @@ def __init__( self.max_seqlen = max_seqlen self.max_pages_per_seq = max_seqlen // self.page_size self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - self.is_cuda_graph = is_cuda_graph + #self.is_cuda_graph = is_cuda_graph # sequences contained in the kv cache, {seq_id: seq_len} self.sequences = OrderedDict() @@ -89,12 +91,11 @@ def allocate_memory(self, layer_number): device=torch.cuda.current_device(), ) self.cache[layer_number] = (k_cache, v_cache) + self.page_table = torch.zeros( + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + ) for i in range(self.total_num_pages): self.free_pages.append(Page(i)) - if self.is_cuda_graph: - self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" - ) def print_cache(self): """Print KV cache status""" @@ -152,10 +153,11 @@ def get_page_table(self, sequences: List[int]): for seq in sequences ] ).to(dtype=torch.int32, device="cpu") - if 
self.is_cuda_graph: - self.page_table[: self.get_sequence_count()].copy_(page_table) - else: - self.page_table = page_table.to(device="cuda") + self.page_table[: self.get_sequence_count()].copy_(page_table) + #if self.is_cuda_graph: + # self.page_table[: self.get_sequence_count()].copy_(page_table) + #else: + # self.page_table = page_table.to(device="cuda") return self.page_table def allocate_page(self, seq: int): @@ -182,12 +184,47 @@ def deallocate_sequence(self, seq: int): self.free_pages.append(page) self.allocated_pages.pop(seq) + def prepare( + self, + sequences: Dict[List, List], + step_dict: Dict[List, List], + ): + self.sequences = sequences + self.step_dict = step_dict + batch_size = len(step_dict) + step_lens = list(step_dict.values()) + cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] + + # Remove finished sequences and advance unfinished sequences + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + for seq in finished_seqs: + self.sequences.pop(seq) + self.deallocate_sequence(seq) + for seq in unfinished_seqs: + if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: + self.allocate_page(seq) + self.sequences[seq] += 1 + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for seq in new_seqs: + self.sequences[seq] = step_dict[seq] + self.allocate_sequence(seq, step_dict[seq]) + + # Get page table + self.page_table = self.get_page_table(list(self.sequences.keys())) + + return self.sequences + def step( self, layer_number: int, k: torch.Tensor, v: torch.Tensor, - step_dict: OrderedDict, + #step_dict: OrderedDict, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, qkv_format: str, ): """ @@ -214,72 +251,96 @@ def step( v_cache: torch.Tensor The value cache tensor containing previous and the current tokens """ - batch_size = len(step_dict) - step_lens = list(step_dict.values()) - cu_seqlens = [0] 
+ [sum(step_lens[:i]) for i in range(1, batch_size + 1)] + #batch_size = len(step_dict) + #step_lens = list(step_dict.values()) + #cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] - # Remove finished sequences and advance unfinished sequences - unfinished_seqs = self.sequences.keys() & step_dict.keys() - finished_seqs = self.sequences.keys() - unfinished_seqs - for seq in finished_seqs: - self.sequences.pop(seq) - self.deallocate_sequence(seq) - for seq in unfinished_seqs: - if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: - self.allocate_page(seq) - self.sequences[seq] += 1 + ## Remove finished sequences and advance unfinished sequences + #unfinished_seqs = self.sequences.keys() & step_dict.keys() + #finished_seqs = self.sequences.keys() - unfinished_seqs + #for seq in finished_seqs: + # self.sequences.pop(seq) + # self.deallocate_sequence(seq) + #for seq in unfinished_seqs: + # if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: + # self.allocate_page(seq) + # self.sequences[seq] += 1 - # Add new sequences - new_seqs = step_dict.keys() - self.sequences.keys() - for seq in new_seqs: - self.sequences[seq] = step_dict[seq] - self.allocate_sequence(seq, step_dict[seq]) + ## Add new sequences + #new_seqs = step_dict.keys() - self.sequences.keys() + #for seq in new_seqs: + # self.sequences[seq] = step_dict[seq] + # self.allocate_sequence(seq, step_dict[seq]) - # Copy new key and value tenosrs to the cache - seqlens = list(self.sequences.values()) - packed_k = torch.Tensor([]).to(dtype=k.dtype, device=k.device) - packed_v = torch.Tensor([]).to(dtype=v.dtype, device=v.device) - for i in range(batch_size): - if qkv_format == "bshd": - packed_k = torch.cat([packed_k, k[i, : step_lens[i], :, :]], dim=0) - packed_v = torch.cat([packed_v, v[i, : step_lens[i], :, :]], dim=0) - if qkv_format == "sbhd": - packed_k = torch.cat([packed_k, k[: step_lens[i], i, :, :]], dim=0) - 
packed_v = torch.cat([packed_v, v[: step_lens[i], i, :, :]], dim=0) - if qkv_format == "thd": - packed_k = k - packed_v = v + ## Copy new key and value tenosrs to the cache + #seqlens = list(self.sequences.values()) + #packed_k = torch.Tensor([]).to(dtype=k.dtype, device=k.device) + #packed_v = torch.Tensor([]).to(dtype=v.dtype, device=v.device) + #for i in range(batch_size): + # if qkv_format == "bshd": + # packed_k = torch.cat([packed_k, k[i, : step_lens[i], :, :]], dim=0) + # packed_v = torch.cat([packed_v, v[i, : step_lens[i], :, :]], dim=0) + # if qkv_format == "sbhd": + # packed_k = torch.cat([packed_k, k[: step_lens[i], i, :, :]], dim=0) + # packed_v = torch.cat([packed_v, v[: step_lens[i], i, :, :]], dim=0) + #if qkv_format == "thd": + # packed_k = k + # packed_v = v k_cache, v_cache = self.cache[layer_number] - for i, seq in enumerate(step_dict.keys()): - page_list = self.get_page_list(seq) - start_page, start_token = self.get_page_token_offsets(seqlens[i] - step_lens[i]) - end_page, end_token = self.get_page_token_offsets(seqlens[i]) - if start_page == end_page: - page_id = page_list[start_page] - k_cache[page_id, start_token:end_token, :, :] = packed_k[ - cu_seqlens[i] : cu_seqlens[i + 1], :, : - ] - v_cache[page_id, start_token:end_token, :, :] = packed_v[ - cu_seqlens[i] : cu_seqlens[i + 1], :, : - ] - else: - start_offset = 0 - end_offset = 0 - for j in range(start_page, end_page + 1): - if not (j == end_page and end_token == 0): - start_token_j = start_token if j == start_page else 0 - end_token_j = end_token if j == end_page else self.page_size - page_id = page_list[start_page] - end_offset = end_token_j - start_token_j - k_cache[page_id, start_token_j:end_token_j, :, :] = packed_k[ - cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : - ] - v_cache[page_id, start_token_j:end_token_j, :, :] = packed_v[ - cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : - ] - start_offset = start_offset + end_offset + step_lens = 
cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + #h=self.num_heads #16 + #d=self.head_dim_k #64 + #b=self.max_batch_size #4 + max_ctx_len=1 #k.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + if qkv_format == "bshd": + max_ctx_len=k.shape[1] + if qkv_format == "sbhd": + max_ctx_len=k.shape[0] + max_seq_len=self.max_seqlen #k_cache.shape[1] #64 #128 + max_ctx_tokens=k.shape[0] + max_tokens=k_cache.shape[0]*k_cache.shape[1] + print('kv shapes ', [x.shape for x in [k, v, k_cache, v_cache]]) + #print('step_lens ', step_lens) + #print('seq_lens ', seq_lens) + #print('self.batch_indices ', self.batch_indices) + print('lensss ', max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + tex.copy_to_kv_cache( + k, v, k_cache, v_cache, + self.page_table, step_lens, seq_lens, + QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, len(self.step_dict), #self.max_batch_size, + max_ctx_len, max_seq_len, self.max_pages_per_seq, False) - # Get page table - page_table = self.get_page_table(list(self.sequences.keys())) + #for i, seq in enumerate(step_dict.keys()): + # page_list = self.get_page_list(seq) + # start_page, start_token = self.get_page_token_offsets(seqlens[i] - step_lens[i]) + # end_page, end_token = self.get_page_token_offsets(seqlens[i]) + # if start_page == end_page: + # page_id = page_list[start_page] + # k_cache[page_id, start_token:end_token, :, :] = packed_k[ + # cu_seqlens[i] : cu_seqlens[i + 1], :, : + # ] + # v_cache[page_id, start_token:end_token, :, :] = packed_v[ + # cu_seqlens[i] : cu_seqlens[i + 1], :, : + # ] + # else: + # start_offset = 0 + # end_offset = 0 + # for j in range(start_page, end_page + 1): + # if not (j == end_page and end_token == 0): + # start_token_j = start_token if j == start_page else 0 + # end_token_j = end_token if j == end_page else self.page_size + # page_id = page_list[start_page] + # end_offset = end_token_j - start_token_j + # k_cache[page_id, 
start_token_j:end_token_j, :, :] = packed_k[ + # cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : + # ] + # v_cache[page_id, start_token_j:end_token_j, :, :] = packed_v[ + # cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : + # ] + # start_offset = start_offset + end_offset + + ## Get page table + #page_table = self.get_page_table(list(self.sequences.keys())) - return k_cache, v_cache, page_table + return k_cache, v_cache, self.page_table From ba5e3330fff3a3c6f5fefa921d5bd1fc7eccd806 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 14 Feb 2025 21:31:35 -0800 Subject: [PATCH 086/239] WIP: padding + BRCM Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 4 ++-- transformer_engine/common/fused_attn/fused_attn.cpp | 2 +- transformer_engine/pytorch/attention.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 76c1feb7f1..046cb5e56e 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -389,7 +389,7 @@ def gen_cu( #cu_dict["max_seqlen_q"] = model_config.max_seqlen_q #cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv cu_dict["inference_params"] = inference_params - cu_dict["attn_mask_type"] = "padding" #"causal" + cu_dict["attn_mask_type"] = "padding_causal" #"causal" # for qkv_format = thd cu_dict["max_seqlen_q"] = model_config.max_ctx_len #max_seqlen_q_infer cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv @@ -565,7 +565,7 @@ def gen_cu( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, - attn_mask_type="padding", + attn_mask_type="padding_causal", max_seqlen_q=max_seqlen_q, #config.max_ctx_len, #max_seqlen_q_infer, max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp 
b/transformer_engine/common/fused_attn/fused_attn.cpp index 1e00cf1a19..6df4aee315 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -228,7 +228,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) (cudnn_runtime_version >= 90600 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + //max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ad316540fc..ef8255359b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7872,8 +7872,8 @@ def forward( # so users can run with the same attn_mask_type for training and inference if "padding" not in attn_mask_type: attn_mask_type = "padding_" + attn_mask_type -# if attn_mask_type in ["causal", "padding_causal"]: -# attn_mask_type = attn_mask_type + "_bottom_right" + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" # convert to cross attention type when KV cache is in use self.attention_type = "cross" From bcef6b34f36f1fc5f6e3ef05aa6cd1793674c337 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 15 Feb 2025 12:28:24 -0800 Subject: [PATCH 087/239] WIP: restructure IP, clean up Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 226 ++----- transformer_engine/pytorch/attention.py | 594 +----------------- .../pytorch/csrc/extensions/attention.cu | 6 - transformer_engine/pytorch/graph.py | 9 - .../pytorch/kv_cache_manager_non_paged.py | 172 +---- 
.../pytorch/kv_cache_manager_paged.py | 104 +-- transformer_engine/pytorch/utils.py | 16 - 7 files changed, 102 insertions(+), 1025 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 046cb5e56e..4fa27c907b 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -109,22 +109,11 @@ def reset(self): self.serving_times = self.arrival_times self.complete_times = self.arrival_times - # time-stepping workflow - # t-1: ... - # compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4], - # batch_size = 3, step_lens = [1, 1, 1] - # increase counter for gen_lens = [3, 10, 5] - # t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15] - # add two new seqs 3 and 4, with ctx lens 10 and 11 - # compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0], - # batch_size = 4, step_lens = [1, 1, 10, 11] - # increase counter for gen_lens = [3, 5, 1, 1] - # batch info at step t self.t_seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu") self.t_ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") self.t_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") - self.t_total_lens = self.t_ctx_lens + self.t_gen_lens #+ self.step_lens + self.t_total_lens = self.t_ctx_lens + self.t_gen_lens self.t_batch_size = 0 # step info from step t-1 to t @@ -244,7 +233,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) # create model - # TODO: multi layers [num_layers] model = ( DotProductAttention( kv_channels=config.head_dim_qk, @@ -276,15 +264,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): dtype=dtype, device="cuda", ) - #print('k_full[0, 0, 0, :4]', k[0, 0, 0, :4]) - print('k_full[7, 46:48, 0, :4]', k[7, 46:48, 0, :4]) - #print('k_full[1, :2, 0, 
:4]', k[1, :2, 0, :4]) - #print('k_full[1, 6, 0, :4]', k[1, 6, 0, :4]) - #print('k_full[0, 17, 0, :4]', k[0, 17, 0, :4]) - #print('k_full[2, 22, 0, :4]', k[2, 22, 0, :4]) - #print('k_full[5, 14, 0, :4]', k[5, 14, 0, :4]) - #print('k_full[6, 12, 0, :4]', k[6, 12, 0, :4]) - # generate reference results logger.info("=== Generating all tokens at once ===") @@ -316,6 +295,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): ) sim.print_setup(logger) + # initialize inference_params inference_params = InferenceParams( max_batch_size=max_batch_size, max_seqlen_kv=config.max_seqlen_kv, @@ -331,97 +311,79 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, ) - # TODO: num_layers inference_params.allocate_memory(layer_number, qkv_format) - #inference_params.print() - - def generate_data( - model_config: ModelConfig, - dtype: torch.dtype, - warmup: bool = False, - qkv_format: str = "bshd", - ) -> List[torch.Tensor]: - """Generate synthetic data for dot product attention.""" - gen_func = torch.ones if warmup else torch.randn + + # graph the model if necessary + if is_cuda_graph: + t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") + step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu") + step_dict = OrderedDict( + zip(t_seq_ids.tolist(), step_lens.tolist()) + ) + inference_params.pre_step(step_dict) + if qkv_format == "bshd": - shape = [ model_config.batch_size, model_config.max_ctx_len] + shape = [ config.batch_size, config.max_ctx_len] if qkv_format == "sbhd": - shape = [ model_config.max_ctx_len, model_config.batch_size] + shape = [ config.max_ctx_len, config.batch_size] if qkv_format == "thd": - shape = [ model_config.batch_size * model_config.max_ctx_len] - aa=[ - gen_func( - #model_config.max_ctx_len, - #model_config.batch_size, + shape = [ config.batch_size * config.max_ctx_len] + def 
gen_data(): + return [torch.ones( *shape, - model_config.num_heads, - model_config.head_dim_qk, + config.num_heads, + config.head_dim_qk, device="cuda", - #requires_grad=True, dtype=dtype, - ) - for _ in range(3) - ] - #print(aa[0].shape, aa[0][8,0,:4]) - #aa.extend([model_config.sequence_length, model_config.sequence_length]) - return aa - - def gen_cu( - model_config: ModelConfig, - dtype: torch.dtype, - ): - cu_dict = {} - cu_dict["cu_seqlens_q"] = torch.linspace( 0, - model_config.batch_size * model_config.max_ctx_len, - #model_config.batch_size * model_config.max_seqlen_q, - steps=model_config.batch_size+1, + ) for _ in range(3)] + + sample_kwargs = {} + sample_kwargs["cu_seqlens_q"] = torch.linspace( 0, + config.batch_size * config.max_ctx_len, + steps=config.batch_size+1, device="cuda", dtype=torch.int32, ) - cu_dict["cu_seqlens_kv"] = torch.linspace( 0, - model_config.batch_size * model_config.max_ctx_len, - #model_config.batch_size * 1, #model_config.max_seqlen_kv, - #model_config.batch_size * model_config.max_seqlen_kv, - steps=model_config.batch_size+1, + sample_kwargs["cu_seqlens_kv"] = torch.linspace( 0, + config.batch_size * config.max_ctx_len, + steps=config.batch_size+1, device="cuda", dtype=torch.int32, ) - #cu_dict["max_seqlen_q"] = model_config.max_seqlen_q - #cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv - cu_dict["inference_params"] = inference_params - cu_dict["attn_mask_type"] = "padding_causal" #"causal" - # for qkv_format = thd - cu_dict["max_seqlen_q"] = model_config.max_ctx_len #max_seqlen_q_infer - cu_dict["max_seqlen_kv"] = model_config.max_seqlen_kv - cu_dict["qkv_format"] = qkv_format - return cu_dict - - t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") - step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu") - step_dict = OrderedDict( - zip(t_seq_ids.tolist(), step_lens.tolist()) - ) - inference_params.prepare(step_dict) - if is_cuda_graph: + 
sample_kwargs["inference_params"] = inference_params + sample_kwargs["attn_mask_type"] = "padding_causal" + sample_kwargs["max_seqlen_q"] = config.max_ctx_len + sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv + sample_kwargs["qkv_format"] = qkv_format + model = make_graphed_callables( model, - generate_data(config, dtype, warmup=True, qkv_format=qkv_format), - num_warmup_iters=3, #10, + gen_data(), + num_warmup_iters=10, fp8_enabled=False, - #sample_kwargs={"qkv_format":"thd"}, - sample_kwargs=gen_cu(config, dtype), + sample_kwargs=sample_kwargs, ) - print('AAAAAAAAAAAAfter graphed') - # similate step by step - sim.reset() - inference_params.reset() - graphed = False - model_orig = model + + sim.reset() + inference_params.reset() + step_dict = OrderedDict() + + # simulate step by step + # t-1: ... + # compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4], + # batch_size = 3, step_lens = [1, 1, 1] + # increase counter for gen_lens = [3, 10, 5] + # t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15] + # add two new seqs 3 and 4, with ctx lens 10 and 11 + # compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0], + # batch_size = 4, step_lens = [1, 1, 10, 11] + # increase counter for gen_lens = [3, 5, 1, 1] max_tokens = config.batch_size * config.max_ctx_len while True: if inference_params.is_paged: inference_params.cache_manager.print_cache() + # prepare batch for the current step dynamic_fill = True #inference_params.is_paged sim.step(dynamic_fill=dynamic_fill) sim.print_step(logger) @@ -437,15 +399,9 @@ def gen_cu( sim.t += 1 continue - #if not is_cuda_graph: - # max_seqlen_q_infer = int((sim.max_ctx_len + 63)// 64 * 64) - #else: - # max_seqlen_q_infer = max_seqlen_kv_roundup - - batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size - max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() #max_seqlen_q_infer, - #max_seqlen_q_infer = sim.max_ctx_len # 
create incremental input + batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size + max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() if qkv_format == "thd": incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") incremental_k = torch.Tensor().to(dtype=dtype, device="cuda") @@ -469,14 +425,12 @@ def gen_cu( dim=0, ) if is_cuda_graph: - print('incremental qkv shapes ', [x.shape for x in [incremental_q, incremental_k, incremental_v]]) incremental_q = torch.cat([incremental_q, torch.zeros([max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], dtype=dtype, device=incremental_q.device)], dim=0) incremental_k = torch.cat([incremental_k, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_k.device)], dim=0) incremental_v = torch.cat([incremental_v, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_v.device)], dim=0) else: incremental_q = torch.zeros( batch_size, - #sim.max_ctx_len, #max_seqlen_q_infer, max_seqlen_q, config.num_heads, config.head_dim_qk, @@ -485,7 +439,6 @@ def gen_cu( ) incremental_k = torch.zeros( batch_size, - #sim.max_ctx_len, #max_seqlen_q_infer, max_seqlen_q, config.num_gqa_groups, config.head_dim_qk, @@ -493,9 +446,7 @@ def gen_cu( device="cuda", ) incremental_v = torch.zeros( - #sim.t_batch_size, batch_size, - #sim.max_ctx_len, #max_seqlen_q_infer, max_seqlen_q, config.num_gqa_groups, config.head_dim_v, @@ -513,52 +464,17 @@ def gen_cu( x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] ] + # run step batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) cu_seqlens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv[1 : 
sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) - print('qkv_format' ,qkv_format, cu_seqlens_q, cu_seqlens_kv) - #print("q[1, 8:10, :2, :2]", q[1, 8:10, :2, :2]) - #print("inc_q[18:20, :2, :2]", incremental_q[18:20, :2, :2]) - step_dict = OrderedDict( zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) ) - inference_params.prepare(step_dict) - - #if sim.step_lens[0] == 1 and not graphed: - # model_graphed = make_graphed_callables( - # model, - # generate_data(config, dtype, warmup=True), - # num_warmup_iters=10, - # fp8_enabled=False, - # #sample_kwargs={"qkv_format":"thd"}, - # sample_kwargs=gen_cu(config, dtype), - # ) - # graphed = True - # print('AAAAAAAAAAAAfter graphed') - #if not graphed: - # model = make_graphed_callables( - # model, - # generate_data(config, dtype, warmup=True), - # num_warmup_iters=10, - # fp8_enabled=False, - # #sample_kwargs={"qkv_format":"thd"}, - # sample_kwargs=gen_cu(config, dtype), - # ) - # graphed = True - # print('AAAAAAAAAAAAfter graphed') - print('incremental shapes', [x.shape for x in [ incremental_q, incremental_k, incremental_v]]) - - #if sim.step_lens[0] == 1 and graphed: - # model = model_graphed - #else: - # model = model_orig + inference_params.pre_step(step_dict) line_output = model( - #query_layer=incremental_q, - #key_layer=incremental_k, - #value_layer=incremental_v, incremental_q, incremental_k, incremental_v, @@ -566,12 +482,12 @@ def gen_cu( cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, attn_mask_type="padding_causal", - max_seqlen_q=max_seqlen_q, #config.max_ctx_len, #max_seqlen_q_infer, + max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, ) - print('llllllllllllllll ', line_output.shape) + # compare results if backend != "FlashAttention": tols = { torch.float32: 1e-3, @@ -586,35 +502,23 @@ def gen_cu( } for i, seq in enumerate(sim.t_seq_ids): if qkv_format == "bshd": - print(i,seq, sim.t_total_lens[i], sim.step_lens[i]) - print(full_output[seq, 
sim.t_total_lens[i] - 1, :4]) - print(line_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( - full_output[seq, sim.t_total_lens[i] - 1, :], - line_output[i, sim.step_lens[i] - 1, :], + full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + line_output[i, :sim.step_lens[i] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "sbhd": - print(i,seq, sim.t_total_lens[i], sim.step_lens[i]) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(line_output[sim.step_lens[i] - 1, i, :4]) torch.testing.assert_close( - full_output[seq, sim.t_total_lens[i] - 1, :], - line_output[sim.step_lens[i] - 1, i, :], + full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + line_output[:sim.step_lens[i] - 1, i, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "thd": - print('iiii ', i, cu_seqlens_q, sim.t_total_lens) - print('thd ', seq, sim.t_total_lens[i], cu_seqlens_q[i + 1]) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(line_output[cu_seqlens_q[i + 1] - 1, :4]) - #print(line_output[cu_seqlens_q[1 + 1] - 1, :4]) - #print(line_output[cu_seqlens_q[2 + 1] - 1, :4]) torch.testing.assert_close( - full_output[seq, sim.t_total_lens[i] - 1, :], - line_output[cu_seqlens_q[i + 1] - 1, :], + full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ef8255359b..fffb44c1c7 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -57,7 +57,6 @@ split_tensor_along_dim, get_device_compute_capability, get_default_init_method, - StaticBufferAllocator, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -80,8 +79,7 @@ ) from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from 
transformer_engine.pytorch.graph import is_graph_capturing -from transformer_engine.pytorch.kv_cache_manager_paged import PagedKVCacheManager -from transformer_engine.pytorch.kv_cache_manager_non_paged import NonPagedKVCacheManager +from transformer_engine.pytorch.inference import InferenceParams from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, prepare_for_saving, @@ -1027,533 +1025,6 @@ def get_attention_backend( available_backends, ) -class KVCacheManager: - """ - KV cache manager. This should be the base class for custom KV cache managers. - """ - def __init__(self, *args, **kwargs): - """Initialize the cache manager""" - self.cache = {} - def allocate_memory(self, layer_number: int): - """Allocate memory for the cache""" - self.cache[layer_number] = (None, None) - def prepare( - self, - sequences: Dict[List, List], - step_dict: Dict[List, List], - ): - """Prepare for step(). Update sequences with step_dict.""" - return sequences - def step( - self, - layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - qkv_format: str, - ): - """Update the cache with new_k and new_v tokens""" - return *self.cache[layer_number], None - -class InferenceParams: # pylint: disable=too-few-public-methods - """ - Inference parameters that are passed to the main model in order - to efficiently calculate and store the context and previously generated tokens - during inference. - - Parameters - ---------- - max_batch_size : int - maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. - num_heads: int - number of attention heads in key/value tensor. - head_dim_k: int - head size for the key tensor. - dtype: torch.dtype - data type for the KV cache. - head_dim_v: Optional[int], default = None - head size for the value tensor. If None, it will be set to head_dim_k. 
- is_paged: bool, default = False - whether the KV cache is paged or non-paged (contiguous). - total_num_pages: Optional[int], default = None - total number of pages in the K cache or V cache if is_paged = True. - page_size: Optional[int], default = None - page size in number of tokens if is_paged = True. - """ - - def __init__( - self, - max_batch_size: int, - max_seqlen_kv: int, - num_heads_kv: int, - head_dim_k: int, - dtype: torch.dtype, - head_dim_v: int = None, - is_paged: bool = False, - total_num_pages: int = None, - page_size: int = None, - num_heads_q: int = None, - head_dim_q: int = None, - max_ctx_len: int = None, - qkv_format: str = "bshd", - cache_manager: KVCacheManager = None, - ): - self.max_batch_size = max_batch_size - self.max_seqlen_kv = max_seqlen_kv - self.num_heads_kv = num_heads_kv - self.head_dim_k = head_dim_k - self.dtype = dtype - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - self.is_paged = is_paged - #self.page_table = None - - if not self.is_paged: - cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager - self.cache_manager = cls( - max_batch_size=self.max_batch_size, - max_seqlen=self.max_seqlen_kv, - num_heads=self.num_heads_kv, - head_dim_k=self.head_dim_k, - dtype=self.dtype, - head_dim_v=self.head_dim_v, - ) - else: - assert page_size is not None, "Paged KV cache requires page_size!" - assert max_seqlen_kv % page_size == 0, "Paged KV cache requires max_seqlen_kv % page_size = 0!" - max_pages_per_seq = max_seqlen_kv // page_size - assert ( - total_num_pages == self.max_batch_size * max_pages_per_seq - ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq!" 
- self.page_size = page_size - #self.max_seqlen_kv = ( - # self.max_seqlen_kv - # if self.max_seqlen_kv >= self.page_size - # else int( - # (self.max_seqlen_kv + self.page_size - 1) // self.page_size * self.page_size - # ) - #) - self.max_seqlen_kv = max_seqlen_kv - self.total_num_pages = total_num_pages - cls = cache_manager if cache_manager is not None else PagedKVCacheManager - self.cache_manager = cls( - total_num_pages=self.total_num_pages, - page_size=self.page_size, - num_heads=self.num_heads_kv, - head_dim_k=self.head_dim_k, - dtype=self.dtype, - max_batch_size=self.max_batch_size, - max_seqlen=self.max_seqlen_kv, - head_dim_v=self.head_dim_v, - ) - - if qkv_format == "thd": - assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" - assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" - assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" - self.num_heads_q = num_heads_q - self.head_dim_q = head_dim_q - self.max_ctx_len = max_ctx_len - - self.input_qkv_format = "bshd" - # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache - self.cache_qkv_format = "bshd" - - # layer numbers that we have kv cache for - #self.layer_numbers = [] - # sequence ids that are stored in the cache - #self.seq_ids = [] - # the full sequence lengths for sequences in seq_ids - #self.seq_lens = [] #0] * self.max_batch_size - self.sequences = collections.OrderedDict() #zip(self.seq_ids, self.seq_lens)) - # the {seq_id: step_len} information for a new inference step - # e.g. inference_params.step_dict = {2: 1, 3: 1, 4: 10}, if in this iteration, - # we have three sequences in the batch: sequences 2 and 3 are in generation phase - # with step_len = 1 and sequence 4 is in context phase with 10 new tokens - #self.step_lens = [] - - # TODO: needed? 
- self.step_dict = collections.OrderedDict() - - # the query buffer when is_cuda_graph = True - #if self.is_cuda_graph: - # self.q_buffer = {} - # self.cu_seqlens_q_buffer = [] - # self.cu_seqlens_kv_buffer = [] - - #def print(self): - # """Print InferenceParams parameters""" - # logger = logging.getLogger("InferenceParams") - # logger.debug("InferenceParams:") - # logger.debug(" dtype: %s", self.dtype) - # logger.debug(" is_paged: %s", self.is_paged) - # if not self.is_paged: - # logger.debug(" max_batch_size: %s", self.max_batch_size) - # logger.debug(" max_seqlen_kv: %s", self.max_seqlen_kv) - # else: - # logger.debug(" total_num_pages: %s", self.total_num_pages) - # logger.debug(" page_size: %s", self.page_size) - # logger.debug(" num_heads_kv: %s", self.num_heads_kv) - # logger.debug(" head_dim: k: %s, v: %s", self.head_dim_k, self.head_dim_v) - # #logger.debug(" layer_numbers: %s", self.layer_numbers) - - def allocate_memory(self, layer_number: int, qkv_format: str): - """ - Allocate memory for the KV cache for the layer #layer_number. - Both K cache and V cache are in 'bshd' format. - - non-paged: - - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] - - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] - - paged: - - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] - - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] - If is_cuda_graph = True, several buffers are also allocated. 
- - Q buffer: [max_batch_size, max_seqlen_kv, num_heads_q, head_dim_q] - - cu_seqlens_q buffer: [max_batch_size + 1] - - cu_seqlens_kv buffer: [max_batch_size + 1] - """ - #self.layer_numbers.append(layer_number) - - self.cache_manager.allocate_memory(layer_number) - if qkv_format == 'thd': #self.is_cuda_graph: - #self.max_seqlen_q = self.max_seqlen_kv - self.q_orig = {} - self.q_buffer = {} - self.q_buffer[layer_number] = torch.zeros( - self.max_batch_size, - self.max_ctx_len, - self.num_heads_q, - self.head_dim_q, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - self.q_dummy = torch.Tensor().to(dtype=self.dtype, device="cuda") - self.batch_indices = torch.Tensor().to(dtype=torch.int32, device="cuda") - self.cu_seqlens_q = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - self.cu_seqlens_kv = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - def reset(self): - #self.cu_seqlens_q.fill_(0) - #self.cu_seqlens_kv.fill_(0) - self.sequences = collections.OrderedDict() #zip(self.seq_ids, self.seq_lens)) - self.step_dict = collections.OrderedDict() - - def prepare( - self, - step_dict: Dict[List, List], - ): - self.sequences = self.cache_manager.prepare(self.sequences, step_dict) - - self.step_dict = step_dict - - actual_batch_size = len(self.step_dict) - seqlens_q = list(self.step_dict.values()) - cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * ( - self.max_batch_size - actual_batch_size - ) - self.seq_lens = list(self.sequences.values()) - - #self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( - self.cu_seqlens_q.copy_( - torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") - ) - cu_seqlens_kv = [0] + [sum(self.seq_lens[:i]) for i in range(1, actual_batch_size + 1)] - cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( - self.max_batch_size - 
actual_batch_size - ) - self.cu_seqlens_kv.copy_( - torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") - ) - -# def reshape_and_copy_q( -# self, -# q: torch.Tensor, -# source_qkv_format: str, -# target_qkv_format: str, # pylint: disable=unused-argument -# layer_number: Optional[int] = None, -# ): -# """ -# Convert the new query tokens from 'source_qkv_format' to 'target_qkv_format', -# so that it is consistent with the KV cache format. At the moment, only 'bshd' format -# is supported for target_qkv_format. If is_cuda_graph = True, also copy the new query -# tensor to the appropriate q_buffer. -# """ -# actual_batch_size = len(self.step_dict) -# seqlens_q = list(self.step_dict.values()) -# cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] -# batch_wide_max_seqlen_q = int((max(seqlens_q) + 63) // 64 * 64) -# if not self.is_cuda_graph: -# if source_qkv_format == "bshd": -# q = q.contiguous() -# if source_qkv_format == "sbhd": -# q = q.transpose(0, 1).contiguous() -# if source_qkv_format == "thd": -# padded_q = torch.zeros( -# actual_batch_size, -# batch_wide_max_seqlen_q, -# q.shape[-2], -# q.shape[-1], -# dtype=q.dtype, -# device="cuda", -# ) -# for i in range(actual_batch_size): -# padded_q[i, : seqlens_q[i], :, :] = q[ -# cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, : -# ] -# q = padded_q -# -# if source_qkv_format in ["bshd", "sbhd"]: -# self.max_seqlen_q = q.shape[1] -# else: -# self.max_seqlen_q = batch_wide_max_seqlen_q -# -# # bshd: [actual_batch_size, batch_wide_max_seqlen_q, num_heads_q, head_dim_q] -# return q -# -# assert ( -# layer_number is not None and layer_number in self.layer_numbers -# ), "layer_number must be an integer and must exist in InferenceParams.layer_numbers!" 
-# q_buffer = self.q_buffer[layer_number] -# for i in range(actual_batch_size): -# if source_qkv_format == "bshd": -# q_buffer[i, : seqlens_q[i], :, :] = q[i, : seqlens_q[i], :, :] -# if source_qkv_format == "sbhd": -# q_buffer[i, : seqlens_q[i], :, :] = q[: seqlens_q[i], i, :, :] -# if source_qkv_format == "thd": -# q_buffer[i, : seqlens_q[i], :, :] = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1], :, :] -# q_buffer[i, seqlens_q[i] :, :, :].fill_(0) -# -# cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) -# self.cu_seqlens_q_buffer.copy_( -# torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") -# ) -# -# # bshd: [self.max_batch_size, self.max_seqlen_kv, num_heads_q, head_dim_q] -# return q_buffer - - def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): - """ - Convert the k cache and v cache from paged to non-paged format. This function - can be used for debugging purposes or for backends that do not have paged attention - support yet, for example, UnfusedDotProductAttention. - - It can be called after update_cache(). Based on the page table, it re-indexes the cache - tensors and returns the contiguous, non-paged, key and value tensors. The kv cache tensors - are assumed to be in 'bshd' format (see self.allocate_memory), and the returned key and - value tensors will be in :attr:`qkv_format` to be consistent with the original inputs. 
- - Parameters - ---------- - layer_number: int - The layer number of the kv cache - qkv_format: str - The format of the returned key and value tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - Non-paged key cache tensor - v_cache: torch.Tensor - Non-paged value cache tensor - """ - k_cache, v_cache = self.cache_manager.cache[layer_number] - page_table = self.cache_manager.page_table - batch_size = page_table.shape[0] - actual_batch_size = len(self.step_dict) - new_k_cache = rearrange( - k_cache[page_table.flatten()], - "(b npages) page_size ... -> b (npages page_size) ...", - b=batch_size, - ) - new_v_cache = rearrange( - v_cache[page_table.flatten()], - "(b npages) page_size ... -> b (npages page_size) ...", - b=batch_size, - ) - for i in range(actual_batch_size): - new_k_cache[i, self.seqlens[i] :, :, :].fill_(0) - new_v_cache[i, self.seqlens[i] :, :, :].fill_(0) - if qkv_format == "bshd": - new_k_cache = new_k_cache.contiguous() - new_v_cache = new_v_cache.contiguous() - if qkv_format == "sbhd": - new_k_cache = new_k_cache.transpose(0, 1).contiguous() - new_v_cache = new_v_cache.transpose(0, 1).contiguous() - if qkv_format == "thd": - packed_k_cache = torch.Tensor().to(dtype=k_cache.dtype, device=k_cache.device) - packed_v_cache = torch.Tensor().to(dtype=v_cache.dtype, device=v_cache.device) - for i in range(batch_size): - packed_k_cache = torch.cat( - [packed_k_cache, new_k_cache[i, : self.seqlens[i], :, :]], dim=0 - ) - packed_v_cache = torch.cat( - [packed_v_cache, new_v_cache[i, : self.seqlens[i], :, :]], dim=0 - ) - new_k_cache = packed_k_cache.contiguous() - new_v_cache = packed_v_cache.contiguous() - return new_k_cache, new_v_cache - - def update_cache( - self, - layer_number: int, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qkv_format: str, - ): - """ - Update KV cache with the new key/value tokens for a given inference iteration. 
- - NonPagedKVCacheManager and PagedKVCacheManager are two examples of the cache manager. - Users can write their own cache manager with their own step() function. - - If the inference iteration has only generation sequences, :attr:`k` and :attr:`v` tensors - should have shape: - - [batch_size, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', - - [1, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and - - [batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. - - If the inference iteration has both generation sequences and context sequences, :attr:`k` - and :attr:`v` should be arranged in a way so that the sequences in generation phase come - before the sequences in context phase, in the tensor. They should have the following shape. - - [batch_size, max_seqlen, num_heads, head_dim] for :attr:`qkv_format` = 'bshd' - - [max_seqlen, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and - - [total_num_new_tokens, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. - Here, max_seqlen is the maximum sequence length for the new tokens in the batch, and it may - be smaller than InferenceParams.max_seqlen_kv. - - Take a batch of 4, with seq_ids = [0, 1, 2, 3], as an example. At iteration t, all 4 sequences - are processed, after which, sequence 2 is determined to be 'finished'. For iteration t+1, there - may or may not be a new sequence added to the batch. - - If no new sequence is added, input tensors :attr:`k` and :attr:`v` should have shape - [3, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', [1, 3, num_heads, head_dim] for - :attr:`qkv_format` = 'sbhd', and [3, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. 
- - If one new sequence is added, for example, sequence 8 with 10 context tokens, then input tensors - :attr:`k` and :attr:`v` should be in [4, 10, num_heads, head_dim] shape if - :attr:`qkv_format` = 'bshd', [10, 4, num_heads, head_dim] if :attr:`qkv_format` = 'sbhd', - or [13, num_heads, head_dim] if :attr:`qkv_format` = 'thd'. - - Parameters - ---------- - layer_number: int - The layer number of the kv cache - k: torch.Tensor - The new key tokens for the current iteration - v: torch.Tensor - The new value tokens for the current iteration - qkv_format: str - The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - The key cache tensor, containing tokens from both previous and current iterations - v_cache: torch.Tensor - The value cache tensor, containing tokens from both previous and current iterations - page_table: torch.Tensor - The page table if is_paged = True; else `None` - """ - #actual_batch_size = len(self.step_dict) - #seqlens_q = list(self.step_dict.values()) - #cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] - #print('qkv_foramt', qkv_format) - #print('qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) - #print('qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) - seqlens_q = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - batch_size = len(seqlens_q) - if qkv_format == "bshd": - q_buffer = q.contiguous() - max_seqlen_q = q_buffer.shape[1] - if qkv_format == "sbhd": - q_buffer = q.transpose(0, 1).contiguous() - max_seqlen_q = q_buffer.shape[1] - if qkv_format == "thd": - #print('---qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) - #print('---qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) - self.q_orig[layer_number] = q - q_buffer = self.q_buffer[layer_number] - #q_buffer_copy = self.q_buffer[layer_number].clone() - ##for i in range(actual_batch_size): - #for i in range(batch_size): - # q_buffer[i, : seqlens_q[i], :, :] = q[self.cu_seqlens_q[i] : self.cu_seqlens_q[i + 1], :, :] - 
##q = q_buffer - max_seqlen_q = self.max_ctx_len - - #max_seqlen_kv = self.max_seqlen_kv - step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - #seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] - max_ctx_len=q.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 - max_seq_len=self.max_ctx_len #q_buffer.shape[1] #64 #128 - #max_ctx_tokens=q.shape[0] - #max_tokens=q_buffer.shape[0]*q_buffer.shape[1] - #print('---++qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) - #print('---++qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) - #print(q_buffer.shape) - #print(self.cu_seqlens_q, self.cu_seqlens_kv, step_lens, seq_lens, QKVFormat[qkv_format]) - print('q xxxxxxxxxxxx ',self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, - max_ctx_len, max_seq_len)#, max_ctx_tokens, max_tokens) - # TODO: batch_indices - tex.reshape_q(q, q_buffer, step_lens, QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.max_batch_size, max_ctx_len, max_seq_len) - #tex.copy_to_kv_cache_non_paged( - # q, self.q_dummy, q_buffer, self.q_dummy, - # self.batch_indices, step_lens, step_lens, - # QKVFormat[qkv_format], self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, - # max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) - #q = q_buffer - #q_buffer = q_buffer_copy - #torch.save(q_buffer, 'q_buffer.pt') - #torch.save(q_buffer_copy, 'q_buffer_copy.pt') - #q = q_buffer - #print('qqqqqqqq', q_buffer.shape, q_buffer.dtype, q_buffer[:2, 8:10, 0, :4]) - - #self.page_table = page_table - #self.seq_ids = list(self.cache_manager.sequences.keys()) - #self.seqlens = list(self.cache_manager.sequences.values()) - self.seq_lens = list(self.sequences.values()) - #print('self.sequences',self.sequences) - #print(self.max_batch_size, actual_batch_size) - - #self.cu_seqlens_q[:len(cu_seqlens_q)].copy_( - # torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") - #) - #cu_seqlens_kv = [0] + [sum(self.seq_lens[:i]) for i in range(1, 
actual_batch_size + 1)] - #cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( - # self.max_batch_size - actual_batch_size - #) - #self.cu_seqlens_kv.copy_( - # torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") - #) - k_cache, v_cache, page_table = self.cache_manager.step( - #layer_number, k, v, self.step_dict, qkv_format, self.cu_seqlens_q, self.cu_seqlens_kv, - layer_number, k, v, self.cu_seqlens_q, self.cu_seqlens_kv, qkv_format, - ) - - #if self.is_cuda_graph: - # actual_batch_size = len(self.seqlens) - # cu_seqlens_kv = [0] + [sum(self.seqlens[:i]) for i in range(1, actual_batch_size + 1)] - # cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( - # self.max_batch_size - actual_batch_size - # ) - # self.cu_seqlens_kv_buffer.copy_( - # torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") - # ) - - # k_cache and v_cache are in InferenceParams.qkv_format format -# return k_cache, v_cache, page_table - return q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, max_seqlen_q, self.max_seqlen_kv - @torch.no_grad() def get_full_mask( @@ -7865,9 +7336,6 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" 
- # remember original format for output purposes - inference_params.input_qkv_format = qkv_format - # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference if "padding" not in attn_mask_type: @@ -7898,9 +7366,9 @@ def forward( # update KV cache and return the full key/value tensors # full key/value tensors are in inference_params.qkv_format format - print('query_layer',query_layer.shape, query_layer.dtype) + #print('query_layer',query_layer.shape, query_layer.dtype) #print('query_layer', query_layer[8,0,:4]) - query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = inference_params.update_cache( + query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, qkv_format = inference_params.step( self.layer_number, query_layer, key_layer, @@ -7917,10 +7385,6 @@ def forward( # max_seqlen_q = inference_params.max_seqlen_q # max_seqlen_kv = inference_params.max_seqlen_kv - # query tensor is now in inference_params.qkv_format - #qkv_format = target_qkv_format - qkv_format = inference_params.cache_qkv_format - cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -7954,8 +7418,8 @@ def forward( max_seqlen_kv, key_layer.device, ) - print('max_seqlen_q ', max_seqlen_q) - print('max_seqlen_kv ', max_seqlen_kv) + #print('max_seqlen_q ', max_seqlen_q) + #print('max_seqlen_kv ', max_seqlen_kv) #print('cu_seqlens_q ', cu_seqlens_q) #print('cu_seqlens_kv ', cu_seqlens_kv) @@ -7973,7 +7437,7 @@ def forward( ) # convert qkv layout to its corresponding paged attention layout if inference_params is not None and inference_params.is_paged: - qkv_layout = "paged_kv_" + qkv_format + "_2" + inference_params.cache_qkv_format + qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format global _alibi_cache if alibi_slopes is not None: @@ 
-8198,7 +7662,7 @@ def forward( fp8_meta=self.fp8_meta, quantizers=self.quantizers, ) - print('ooooooooooo ',output.shape) + #print('ooooooooooo ',output.shape) #print(output[1,9,:4]) #print(output[1,10,:4]) @@ -8245,49 +7709,7 @@ def forward( ) if inference_params is not None: - batch_size = len(inference_params.step_dict) - step_lens = list(inference_params.step_dict.values()) - max_seqlen_q = max(list(inference_params.step_dict.values())) - print('xxxxxxxxx ', batch_size, step_lens, max_seqlen_q, inference_params.step_dict, inference_params.input_qkv_format, output.shape) - #ooo = output.view(output.shape[:2], -1) - #print('output ', output[0,0,:4]) - #print('output ', output[1,0,:4]) - #print('output ', output[0,0,:4]) - #print('output ', output[1,6,:4]) - if inference_params.input_qkv_format == "bshd": - output = output[:batch_size, :max_seqlen_q].contiguous() - if inference_params.input_qkv_format == "sbhd": - output = output[:batch_size, :max_seqlen_q].transpose(0, 1).contiguous() - if inference_params.input_qkv_format == "thd": - output_buffer = inference_params.q_orig[self.layer_number] - #packed_output = torch.Tensor().to(dtype=output.dtype, device=output.device) - #for i in range(batch_size): - # packed_output = torch.cat([packed_output, output[i, : step_lens[i]]], dim=0) - #output = packed_output.contiguous() - - #max_seqlen_kv = self.max_seqlen_kv - #step_lens = inference_params.cu_seqlens_q[1:] - inference_params.cu_seqlens_q[:-1] - step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - #seq_lens = self.cu_seqlens_kv[1:] - self.cu_seqlens_kv[:-1] - max_ctx_len=1 #output.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 - max_seq_len=inference_params.max_ctx_len #q_buffer.shape[1] #64 #128 - #max_ctx_tokens=q.shape[0] - #max_tokens=q_buffer.shape[0]*q_buffer.shape[1] - #print('---++qqqqqqqq0', q.shape, q.dtype, q[8, 0, :4]) - #print('---++qqqqqqqq1', q.shape, q.dtype, q[18:20, 0, :4]) - #print(q_buffer.shape) - #print(self.cu_seqlens_q, 
self.cu_seqlens_kv, step_lens, seq_lens, QKVFormat[qkv_format]) - #print('o xxxxxxxxxxxx ',step_lens, #self.num_heads_q, self.head_dim_q, self.head_dim_q, self.max_batch_size, - # max_ctx_len, max_seq_len, output.shape, output_buffer.shape)#, max_ctx_tokens, max_tokens) - # TODO: batch_indices - tex.reshape_o(output, output_buffer, step_lens, - inference_params.num_heads_q, inference_params.head_dim_q, inference_params.max_batch_size, max_seq_len) #, max_ctx_tokens, max_tokens) - #tex.copy_to_kv_cache_non_paged( - # inference_params.q_dummy, output, inference_params.q_dummy, output_buffer, - # inference_params.batch_indices, step_lens, step_lens, - # QKVFormat[qkv_format], inference_params.num_heads_q, inference_params.head_dim_q, inference_params.head_dim_q, inference_params.max_batch_size, - # max_ctx_len, max_seq_len) #, max_ctx_tokens, max_tokens) - output = output_buffer.view(output_buffer.shape[0], -1) + output = inference_params.post_step(self.layer_number, output) return output diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index f168eb4178..5a518adf71 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -66,7 +66,6 @@ void reshape_q_launcher( torch::Tensor step_lens, NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { - printf("-------- 3 %p %d %d %d \n"); //, new_v.data_ptr(), h_kv, d_k, d_v); reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_q.data_ptr()), reinterpret_cast(q_buffer.data_ptr()), @@ -136,7 +135,6 @@ void reshape_o_launcher( torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, int h_o, int d_o, int b, int max_seq_len) { - printf("-------- 4 %p %d %d %d \n"); //, new_v.data_ptr(), h_kv, d_k, d_v); reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( 
reinterpret_cast(output.data_ptr()), reinterpret_cast(output_buffer.data_ptr()), @@ -187,9 +185,6 @@ __global__ void reindex_kv_cache_kernel( actual_b = i+1; } } - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("actual_b is %d\n", actual_b); - } for (int batch_idx = 0; batch_idx < actual_b; batch_idx ++) { for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; token_idx += gridDim.x) { int num_elts_k = h_kv * d_k; @@ -295,7 +290,6 @@ void copy_to_kv_cache_launcher( // 6. for THD, assumes no padding between sequences in new_k and new_v if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) { - printf("-------- 1 %p %d %d %d \n"); //, new_v.data_ptr(), h_kv, d_k, d_v); if (is_non_paged) { reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(k_cache.data_ptr()), diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 3d252cc1e7..63053d7ec9 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -247,15 +247,10 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] for ii in range(num_warmup_iters): - print("------ warmup ", ii) hooks = [] for module in func.modules(): hook = module.register_forward_hook(hook_fn) hooks.append(hook) - print(len(args), [x.shape for x in args]) - print(len(args), [x.dtype for x in args]) - #print(args[0][8,0,:4]) - print(kwargs) outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: hook.remove() @@ -270,7 +265,6 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument else: grad_inputs = None del outputs, grad_inputs - print("------ end warmup ------") # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from 
altering the control flow. for module in func.modules(): @@ -432,9 +426,6 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Copy values from new tensors into static tensors for i in range(len_user_args): if isinstance(static_input_surface[i], torch.Tensor) and static_input_surface[i].data_ptr() != inputs[i].data_ptr(): - print(i, inputs[i].shape, static_input_surface[i].shape) - if inputs[i].ndim == 1: - print('input', i, inputs[i]) static_input_surface[i].copy_(inputs[i]) # Replay forward graph diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index a750518edc..6dc17f7f49 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -6,39 +6,10 @@ from collections import OrderedDict from typing import Optional, Dict, List import torch -#from transformer_engine.pytorch.utils import StaticBufferAllocator import transformer_engine_torch as tex +from transformer_engine.pytorch.kv_cache_manager import KVCacheManager from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -class KVCacheManager: - """ - KV cache manager. This should be the base class for custom KV cache managers. - """ - def __init__(self, *args, **kwargs): - """Initialize the cache manager""" - self.cache = {} - def allocate_memory(self, layer_number: int): - """Allocate memory for the cache""" - self.cache[layer_number] = (None, None) - def prepare( - self, - sequences: Dict[List, List], - step_dict: Dict[List, List], - ): - """Prepare for step(). 
Update sequences with step_dict.""" - return sequences - def step( - self, - layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - qkv_format: str, - ): - """Update the cache with new_k and new_v tokens""" - return *self.cache[layer_number], None - class NonPagedKVCacheManager(KVCacheManager): """ The non-paged KV cache manager. @@ -52,7 +23,6 @@ def __init__( head_dim_k: int, dtype: torch.dtype, head_dim_v: Optional[int] = None, - #is_cuda_graph: bool = False, ): """Initialize the KV cache""" self.max_batch_size = max_batch_size @@ -61,20 +31,10 @@ def __init__( self.head_dim_k = head_dim_k self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - #self.is_cuda_graph = is_cuda_graph - # sequences contained in the kv cache, {seq_id: seq_len} - #self.sequences = OrderedDict() - # KV cache tuple (k_cache, v_cache) self.cache = {} + self.sequences = OrderedDict() self.batch_indices = None -# self._allocator = StaticBufferAllocator() -# -# def alloc(self, size, dtype, device): -# """ -# Allocated the buffer and works correctly with CUDA Graphs. 
-# """ -# return self._allocator(size, dtype, device) def allocate_memory(self, layer_number): """Allocate memory for the KV cache""" @@ -96,25 +56,18 @@ def allocate_memory(self, layer_number): ) self.cache[layer_number] = (k_cache, v_cache) - #self.batch_indices = self.alloc( self.batch_indices = torch.zeros( self.max_batch_size, dtype=torch.int32, device=torch.cuda.current_device(), - ) + ) - def prepare( + def pre_step( self, - sequences: Dict[List, List], step_dict: Dict[List, List], ): - # TODO: remove - self.sequences = sequences - self.step_dict = step_dict - prev_batch_size = len(self.sequences) - batch_size = len(step_dict) - # Reorder cache + prev_batch_size = len(self.sequences) unfinished_seqs = self.sequences.keys() & step_dict.keys() finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] @@ -124,7 +77,6 @@ def prepare( + finished_indices + list(range(prev_batch_size, self.max_batch_size)) )).to(dtype=torch.int32, device="cpu")) - print('self.batch_indices', self.batch_indices) # Advance unfinished sequences for i in unfinished_seqs: @@ -176,115 +128,21 @@ def step( The value cache tensor containing previous and the current tokens """ k_cache, v_cache = self.cache[layer_number] - #kk=k_cache.clone() - #k_cache1 = kk[self.batch_indices].contiguous() - #k_cache = k_cache[self.batch_indices].contiguous() - #v_cache = v_cache[self.batch_indices].contiguous() step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - #h=self.num_heads #16 - #d=self.head_dim_k #64 - #b=self.max_batch_size #4 - max_ctx_len=1 #k.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + batch_size = self.max_batch_size + ctx_len=1 if qkv_format == "bshd": - max_ctx_len=k.shape[1] + batch_size = k.shape[0] + ctx_len=k.shape[1] if qkv_format == "sbhd": - max_ctx_len=k.shape[0] - max_seq_len=self.max_seqlen #k_cache.shape[1] #64 #128 - 
max_ctx_tokens=k.shape[0] - max_tokens=k_cache.shape[0]*k_cache.shape[1] - print('kv shapes ', [x.shape for x in [k, v, k_cache, v_cache]]) - #print('step_lens ', step_lens) - #print('seq_lens ', seq_lens) - #print('self.batch_indices ', self.batch_indices) - #print('lensss ', qkv_format, step_lens, seq_lens,max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + batch_size = k.shape[1] + ctx_len=k.shape[0] tex.copy_to_kv_cache( k, v, k_cache, v_cache, self.batch_indices, step_lens, seq_lens, - QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, len(self.step_dict), #self.max_batch_size, - max_ctx_len, max_seq_len, 1, True)#, max_ctx_tokens, max_tokens) - #print('self.batch_indices after', self.batch_indices) - #print('k_cache[0, 0, 0, :4]', k_cache[0, 0, 0, :4]) - #print('k_cache[0, 46:48, 0, :4]', k_cache[0, 46:48, 0, :4]) - #print('k_cache[1, :2, 0, :4]', k_cache[1, :2, 0, :4]) - #print('k_cache[1, 6, 0, :4]', k_cache[1, 6, 0, :4]) - #print('k_cache[0, 17, 0, :4]', k_cache[0, 17, 0, :4]) - #print('k_cache[1, 22, 0, :4]', k_cache[1, 22, 0, :4]) - #print('k_cache[2, 14, 0, :4]', k_cache[2, 14, 0, :4]) - #print('k_cache[3, 12, 0, :4]', k_cache[3, 12, 0, :4]) - #print(k_cache1[0, :2, 0, :4]) - #print(k_cache1[1, :2, 0, :4]) - #print(k_cache[0, :2, 0, :4]) - #print(k_cache[1, :2, 0, :4]) - self.cache[layer_number] = k_cache, v_cache - return k_cache, v_cache, None + QKVFormat[qkv_format], + self.num_heads, self.head_dim_k, self.head_dim_v, + batch_size, ctx_len, self.max_seqlen, 1, True) -# #prev_batch_size = len(self.sequences) -# #batch_size = len(step_dict) -# batch_size = len(self.sequences) -# -# ## Reorder cache -# #unfinished_seqs = self.sequences.keys() & step_dict.keys() -# #finished_seqs = self.sequences.keys() - unfinished_seqs -# #unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] -# #finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] -# #batch_indices = ( -# # 
unfinished_indices -# # + finished_indices -# # + list(range(prev_batch_size, self.max_batch_size)) -# #) -# new_k_cache = k_cache[self.batch_indices, :] -# new_v_cache = v_cache[self.batch_indices, :] -# new_k_cache = new_k_cache.contiguous() -# new_v_cache = new_v_cache.contiguous() -# -# ## Advance unfinished sequences -# #for i in unfinished_seqs: -# # self.sequences[i] += 1 -# -# ## Remove finished sequences -# #for i in finished_seqs: -# # self.sequences.pop(i) -# -# ## Add new sequences -# #new_seqs = step_dict.keys() - self.sequences.keys() -# #for i in new_seqs: -# # self.sequences[i] = step_dict[i] -# -# # Copy new key/value tokens to cache -# #step_lens = list(step_dict.values()) -# #cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] -# cu_seqlens = cu_seqlens_q -# step_lens = cu_seqlens[1:] - cu_seqlens[:-1] -# seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] -# #print('self.sequences', self.sequences) -# #print('cu_seqlens_q', cu_seqlens_q) -# #print('cu_seqlens_kv', cu_seqlens_kv) -# #print('step_lens', step_lens) -# for i, seq in enumerate(self.sequences.keys()): -# print('kv cm non-paged i', i, 'seq', seq) -# #seq_s = self.sequences[seq] - step_lens[i] -# #seq_e = self.sequences[seq] -# seq_s = seq_lens[i] - step_lens[i] -# seq_e = seq_lens[i] -# if qkv_format == "bshd": -# print('bshd ', [x.device for x in [new_k_cache, step_lens]]) -# new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_lens[i], :, :] -# new_v_cache[i, seq_s:seq_e, :, :] = v[i, : step_lens[i], :, :] -# if qkv_format == "sbhd": -# new_k_cache[i, seq_s:seq_e, :, :] = k[: step_lens[i], i, :, :] -# new_v_cache[i, seq_s:seq_e, :, :] = v[: step_lens[i], i, :, :] -# if qkv_format == "thd": -# new_k_cache[i, seq_s:seq_e, :, :] = k[cu_seqlens[i] : cu_seqlens[i + 1], :, :] -# new_v_cache[i, seq_s:seq_e, :, :] = v[cu_seqlens[i] : cu_seqlens[i + 1], :, :] -# self.cache[layer_number] = (new_k_cache, new_v_cache) -# -# # Return full key/value tensors for attention calculation 
-# if self.is_cuda_graph: -# # [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] -# return new_k_cache, new_v_cache, None -# -# # [actual_batch_size, max_seqlen_kv, num_heads_kv, head_dim_kv] -# new_k_cache = new_k_cache[:batch_size].contiguous() -# new_v_cache = new_v_cache[:batch_size].contiguous() -# return new_k_cache, new_v_cache, None + return k_cache, v_cache, None diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index 95d9b0c02c..a84dd42de8 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -9,7 +9,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.kv_cache_manager_non_paged import KVCacheManager +from transformer_engine.pytorch.kv_cache_manager import KVCacheManager from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat @@ -67,6 +67,8 @@ def __init__( self.cache = {} # free pages allowed to allocate, [Page(),...] 
self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) # allocated pages, {seq_id: [page_id,...]} self.allocated_pages = defaultdict(list) # page table, [batch_size, max_pages_per_seq] @@ -91,11 +93,10 @@ def allocate_memory(self, layer_number): device=torch.cuda.current_device(), ) self.cache[layer_number] = (k_cache, v_cache) + self.page_table = torch.zeros( self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" ) - for i in range(self.total_num_pages): - self.free_pages.append(Page(i)) def print_cache(self): """Print KV cache status""" @@ -184,13 +185,10 @@ def deallocate_sequence(self, seq: int): self.free_pages.append(page) self.allocated_pages.pop(seq) - def prepare( + def pre_step( self, - sequences: Dict[List, List], step_dict: Dict[List, List], ): - self.sequences = sequences - self.step_dict = step_dict batch_size = len(step_dict) step_lens = list(step_dict.values()) cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] @@ -251,96 +249,22 @@ def step( v_cache: torch.Tensor The value cache tensor containing previous and the current tokens """ - #batch_size = len(step_dict) - #step_lens = list(step_dict.values()) - #cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] - - ## Remove finished sequences and advance unfinished sequences - #unfinished_seqs = self.sequences.keys() & step_dict.keys() - #finished_seqs = self.sequences.keys() - unfinished_seqs - #for seq in finished_seqs: - # self.sequences.pop(seq) - # self.deallocate_sequence(seq) - #for seq in unfinished_seqs: - # if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: - # self.allocate_page(seq) - # self.sequences[seq] += 1 - - ## Add new sequences - #new_seqs = step_dict.keys() - self.sequences.keys() - #for seq in new_seqs: - # self.sequences[seq] = step_dict[seq] - # self.allocate_sequence(seq, step_dict[seq]) - - ## Copy new key and value tenosrs to the 
cache - #seqlens = list(self.sequences.values()) - #packed_k = torch.Tensor([]).to(dtype=k.dtype, device=k.device) - #packed_v = torch.Tensor([]).to(dtype=v.dtype, device=v.device) - #for i in range(batch_size): - # if qkv_format == "bshd": - # packed_k = torch.cat([packed_k, k[i, : step_lens[i], :, :]], dim=0) - # packed_v = torch.cat([packed_v, v[i, : step_lens[i], :, :]], dim=0) - # if qkv_format == "sbhd": - # packed_k = torch.cat([packed_k, k[: step_lens[i], i, :, :]], dim=0) - # packed_v = torch.cat([packed_v, v[: step_lens[i], i, :, :]], dim=0) - #if qkv_format == "thd": - # packed_k = k - # packed_v = v k_cache, v_cache = self.cache[layer_number] step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - #h=self.num_heads #16 - #d=self.head_dim_k #64 - #b=self.max_batch_size #4 - max_ctx_len=1 #k.shape[1] if qkv_format in ["bshd", "sbhd"] else 1 #64 + batch_size = self.max_batch_size + ctx_len=1 if qkv_format == "bshd": - max_ctx_len=k.shape[1] + batch_size = k.shape[0] + ctx_len=k.shape[1] if qkv_format == "sbhd": - max_ctx_len=k.shape[0] - max_seq_len=self.max_seqlen #k_cache.shape[1] #64 #128 - max_ctx_tokens=k.shape[0] - max_tokens=k_cache.shape[0]*k_cache.shape[1] - print('kv shapes ', [x.shape for x in [k, v, k_cache, v_cache]]) - #print('step_lens ', step_lens) - #print('seq_lens ', seq_lens) - #print('self.batch_indices ', self.batch_indices) - print('lensss ', max_ctx_len, max_seq_len, max_ctx_tokens, max_tokens) + batch_size = k.shape[1] + ctx_len=k.shape[0] tex.copy_to_kv_cache( k, v, k_cache, v_cache, self.page_table, step_lens, seq_lens, - QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, len(self.step_dict), #self.max_batch_size, - max_ctx_len, max_seq_len, self.max_pages_per_seq, False) - - #for i, seq in enumerate(step_dict.keys()): - # page_list = self.get_page_list(seq) - # start_page, start_token = self.get_page_token_offsets(seqlens[i] - step_lens[i]) - # end_page, 
end_token = self.get_page_token_offsets(seqlens[i]) - # if start_page == end_page: - # page_id = page_list[start_page] - # k_cache[page_id, start_token:end_token, :, :] = packed_k[ - # cu_seqlens[i] : cu_seqlens[i + 1], :, : - # ] - # v_cache[page_id, start_token:end_token, :, :] = packed_v[ - # cu_seqlens[i] : cu_seqlens[i + 1], :, : - # ] - # else: - # start_offset = 0 - # end_offset = 0 - # for j in range(start_page, end_page + 1): - # if not (j == end_page and end_token == 0): - # start_token_j = start_token if j == start_page else 0 - # end_token_j = end_token if j == end_page else self.page_size - # page_id = page_list[start_page] - # end_offset = end_token_j - start_token_j - # k_cache[page_id, start_token_j:end_token_j, :, :] = packed_k[ - # cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : - # ] - # v_cache[page_id, start_token_j:end_token_j, :, :] = packed_v[ - # cu_seqlens[i] + start_offset : cu_seqlens[i] + end_offset, :, : - # ] - # start_offset = start_offset + end_offset - - ## Get page table - #page_table = self.get_page_table(list(self.sequences.keys())) + QKVFormat[qkv_format], + self.num_heads, self.head_dim_k, self.head_dim_v, + batch_size, ctx_len, self.max_seqlen, self.max_pages_per_seq, False) return k_cache, v_cache, self.page_table diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index ee4df62020..5b1bd82221 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -326,19 +326,3 @@ def round_up_to_nearest_multiple(value, multiple): if multiple == 0: raise ValueError("multiple cannot be zero.") return ((value + multiple - 1) // multiple) * multiple - -class StaticBufferAllocator(torch.nn.Module): - """ - This class is used when we use te.make_graphed_callable(). - CUDA Graphs require all tensors to be static. Neverthless, - torch API make_graphed_callable() takes care of output of torch modules, - and makes them static. 
Thus by wrapping allocation of memory into - torch.nn.Module, we can greatly simplify our code. - """ - - # pylint: disable=no-self-use - def forward(self, size, dtype, device): - """ - Return buffer of given size, dtype and device. - """ - return torch.zeros(size, dtype=dtype, device=device) From f3975f09429b69492e0bdf1171b6c3559824d2ed Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 15 Feb 2025 13:07:50 -0800 Subject: [PATCH 088/239] WIP: fix non-CG, fused Signed-off-by: Charlene Yang --- transformer_engine/pytorch/kv_cache_manager_non_paged.py | 3 ++- transformer_engine/pytorch/kv_cache_manager_paged.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 6dc17f7f49..0ca8efa722 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -144,5 +144,6 @@ def step( QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, batch_size, ctx_len, self.max_seqlen, 1, True) - + k_cache = k_cache[:batch_size, :ctx_len] + v_cache = v_cache[:batch_size, :ctx_len] return k_cache, v_cache, None diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index a84dd42de8..f0b8b622eb 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -266,5 +266,6 @@ def step( QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, batch_size, ctx_len, self.max_seqlen, self.max_pages_per_seq, False) + page_table = self.page_table[:batch_size] - return k_cache, v_cache, self.page_table + return k_cache, v_cache, page_table From 125548c03d499e62341f284b5246c7bbe439dd78 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 15 Feb 2025 13:13:57 -0800 Subject: [PATCH 089/239] WIP: fix last commit Signed-off-by: 
Charlene Yang --- transformer_engine/pytorch/kv_cache_manager_non_paged.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 0ca8efa722..c4f8d59a65 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -144,6 +144,6 @@ def step( QKVFormat[qkv_format], self.num_heads, self.head_dim_k, self.head_dim_v, batch_size, ctx_len, self.max_seqlen, 1, True) - k_cache = k_cache[:batch_size, :ctx_len] - v_cache = v_cache[:batch_size, :ctx_len] + k_cache = k_cache[:batch_size] + v_cache = v_cache[:batch_size] return k_cache, v_cache, None From 9bf3204ab3e29992eb88e67d33cef4f8e6e5a003 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 15 Feb 2025 14:13:08 -0800 Subject: [PATCH 090/239] WIP: unfused, non-CG Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 4 +- transformer_engine/pytorch/attention.py | 51 +++++++++++++++++-- .../pytorch/kv_cache_manager_paged.py | 4 -- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 4fa27c907b..e549467d4d 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -198,7 +198,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["UnfusedAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() 
@@ -231,6 +231,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention")) os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention")) os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) + if backend == "UnfusedAttention" and is_cuda_graph: + pytest.skip("CUDA graph is not supported for UnfusedAttention backend") # create model model = ( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fffb44c1c7..8fd1918fc7 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -17,7 +17,6 @@ from dataclasses import dataclass, fields import numpy as np from packaging.version import Version as PkgVersion -from einops import rearrange import torch import torch.nn.functional as F @@ -1025,6 +1024,45 @@ def get_attention_backend( available_backends, ) +@torch.no_grad() +def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) + attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) + for i in range(batch_size): + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor( + [False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask_kv = torch.cat( + [ + attention_mask_kv, + torch.Tensor( + [False] * seqlens_kv[i] + + [True] * (max_seqlen_kv - seqlens_kv[i]) + ) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask = ( + attention_mask_q.to(device="cuda"), + attention_mask_kv.to(device="cuda"), + ) + return attention_mask @torch.no_grad() def get_full_mask( @@ -1138,7 +1176,6 @@ def 
get_full_mask( m = attention_mask.logical_not() actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) - # apply SWA mask mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 @@ -5135,7 +5172,7 @@ def forward( ) if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged( - self.layer_number, qkv_format + self.layer_number, inference_params.input_qkv_format ) if qkv_format == "bshd": @@ -5149,6 +5186,8 @@ def forward( key_layer.shape[0], ) + if "padding" in attn_mask_type and qkv_format in ["bshd", "sbhd"]: + attention_mask = get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( max_seqlen_q, max_seqlen_kv, @@ -7338,8 +7377,8 @@ def forward( # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference - if "padding" not in attn_mask_type: - attn_mask_type = "padding_" + attn_mask_type + #if "padding" not in attn_mask_type: + # attn_mask_type = "padding_" + attn_mask_type if attn_mask_type in ["causal", "padding_causal"]: attn_mask_type = attn_mask_type + "_bottom_right" @@ -7375,8 +7414,10 @@ def forward( value_layer, qkv_format, ) + #print('ssss0 ',query_layer.shape, key_layer.shape, value_layer.shape) #print('cu_seqlens_q',cu_seqlens_q) #print('cu_seqlens_kv',cu_seqlens_kv) + #print('maxxxxx ',max_seqlen_q, max_seqlen_kv) # update cu_seqlens tensors #if inference_params.is_cuda_graph: diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index f0b8b622eb..f22313c972 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -155,10 +155,6 @@ def get_page_table(self, sequences: 
List[int]): ] ).to(dtype=torch.int32, device="cpu") self.page_table[: self.get_sequence_count()].copy_(page_table) - #if self.is_cuda_graph: - # self.page_table[: self.get_sequence_count()].copy_(page_table) - #else: - # self.page_table = page_table.to(device="cuda") return self.page_table def allocate_page(self, seq: int): From 3060892f1925a3af8e316d6422a68f6fea70cfb7 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 15 Feb 2025 21:42:50 -0800 Subject: [PATCH 091/239] WIP: flash-attn, non-CG Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 13 ++++++++++++- transformer_engine/pytorch/attention.py | 7 ++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index e549467d4d..75e23f89df 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -198,7 +198,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["UnfusedAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FlashAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() @@ -233,6 +233,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") + if backend == "FlashAttention": + config.max_seqlen_q = 256 + config.max_seqlen_kv = 256 # create model model = ( @@ -458,6 +461,7 @@ def gen_data(): for i, seq in 
enumerate(sim.t_seq_ids): start = (sim.t_total_lens[i] - sim.step_lens[i]).item() end = sim.t_total_lens[i].item() + print('i, seq', i, seq, start, end, sim.step_lens[i], incremental_q.shape, q.shape) incremental_q[i, : sim.step_lens[i], :, :] = q[seq, start:end, :, :] incremental_k[i, : sim.step_lens[i], :, :] = k[seq, start:end, :, :] incremental_v[i, : sim.step_lens[i], :, :] = v[seq, start:end, :, :] @@ -488,6 +492,7 @@ def gen_data(): max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, ) + print("lllllllll ", line_output.shape) # compare results if backend != "FlashAttention": @@ -504,9 +509,15 @@ def gen_data(): } for i, seq in enumerate(sim.t_seq_ids): if qkv_format == "bshd": + print('seqq ', i, seq, sim.t_total_lens[i], sim.step_lens[i]) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[i, :, :4]) + #print(line_output[i, sim.step_lens[i] - 1, :]) torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], line_output[i, :sim.step_lens[i] - 1, :], + #full_output[seq, sim.t_total_lens[i] - 1, :], + #line_output[i, sim.step_lens[i] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8fd1918fc7..f2db716853 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5646,7 +5646,6 @@ def forward( qkv_format = "".join( [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] ) - if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): if qkv_format == "sbhd": # For now just 128, will make it more general in the future @@ -5690,6 +5689,8 @@ def forward( if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" 
+ cu_seqlens_q = cu_seqlens_q[:batch_size+1] + cu_seqlens_kv = cu_seqlens_kv[:batch_size+1] if inference_params is None or ( inference_params is not None and not inference_params.is_paged @@ -5825,8 +5826,8 @@ def forward( else: if _flash_attn_2_5_7_plus: fa_optional_forward_kwargs["block_table"] = None - if inference_params is not None: - fa_optional_forward_kwargs["block_table"] = inference_params.page_table + if inference_params is not None and inference_params.is_paged: + fa_optional_forward_kwargs["block_table"] = inference_params.cache_manager.page_table[:batch_size] func = ( flash_attn_varlen_func if not _use_flash_attn_3 From eb9857d6af073ae1d86a9865b30601459855e60a Mon Sep 17 00:00:00 2001 From: hx Date: Tue, 18 Feb 2025 11:58:03 -0800 Subject: [PATCH 092/239] [MoE][PyTorch] Add prob permutation to mask-based MoE permutation; Fix FP8 related codes (#1468) * add prob permute; fix fp8tensor Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert unnecessary changes in UT Signed-off-by: Hongxiao Bai * remove unnecessary probs dtype convert Signed-off-by: Hongxiao Bai * keep the output nums if probs is not provided Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine the doc string Signed-off-by: Hongxiao Bai * fix lint Signed-off-by: Hongxiao Bai * use fp32 compute type Signed-off-by: Hongxiao Bai * style fix Signed-off-by: Hongxiao Bai * fix empty input return Signed-off-by: Hongxiao Bai * separate prob related functions out Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen --- docs/api/pytorch.rst | 4 + 
tests/pytorch/test_permutation.py | 435 +++++++++++++++--- transformer_engine/pytorch/__init__.py | 2 + transformer_engine/pytorch/permutation.py | 221 +++++++-- .../pytorch/triton/permutation.py | 222 ++++++--- 5 files changed, 721 insertions(+), 163 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 6d5fe6761d..4154a18598 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -48,10 +48,14 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute +.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs + .. autoapifunction:: transformer_engine.pytorch.moe_unpermute .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index +.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs + .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 35c6266a3f..0dc183e298 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -10,12 +10,14 @@ from transformer_engine.pytorch import ( moe_permute as te_permute, + moe_permute_with_probs as te_permute_with_probs, moe_unpermute as te_unpermute, moe_sort_chunks_by_index as te_sort_chunks_by_index, + moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, ) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine_torch as tex @@ -198,6 +200,16 @@ def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: raise ValueError(f"Unsuppored dtype ({te_dtype})") +def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False +): + # Set 
forward_input.grad to None to avoid grad accumulation. + if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + def _test_permutation_index_map( te_dtype, num_tokens, @@ -265,9 +277,9 @@ def _test_permutation_index_map( permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.dequantize().to(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.dequantize().to(torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize().to(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -341,10 +353,10 @@ def _test_permutation_index_map( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.dequantize().to(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize().to(torch.float32) - te_unpermute_output_ = te_unpermute_output.dequantize().to(torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize().to(torch.float32) + te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() @@ -388,15 +400,6 @@ def 
_test_permutation_index_map( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. - if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens) @@ -509,19 +512,28 @@ def _test_permutation_mask_map( size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" ) - permute_fwd_input = Float8Tensor.to_float8( - permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - permute_bwd_input = Float8Tensor.to_float8( - permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - unpermute_bwd_input = Float8Tensor.to_float8( - unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _unpermute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) + permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) + unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) - pytorch_unpermute_bwd_input = 
unpermute_bwd_input.from_float8(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -541,6 +553,10 @@ def _test_permutation_mask_map( probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) probs = probs / row_sums + if fp8: + probs = probs.to(torch.float16) + else: + probs = probs.to(dtype) probs.requires_grad_(True) ################################################################################################################################### @@ -571,7 +587,7 @@ def _test_permutation_mask_map( te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( - te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask" + te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" ) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) @@ -596,10 +612,10 @@ def _test_permutation_mask_map( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.from_float8(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) - te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + te_unpermute_fwd_input_grad = 
te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() @@ -644,21 +660,14 @@ def _test_permutation_mask_map( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. - if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map) ) t2 = perf_test_cuda_kernel( - lambda: te_permute(te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask") + lambda: te_permute( + te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" + ) ) print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -752,15 +761,21 @@ def _test_moe_chunk_sort( fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - fwd_input = Float8Tensor.to_float8( - fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - bwd_input = Float8Tensor.to_float8( - bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + fwd_input = _fwd_input_quantizer.quantize(fwd_input) + bwd_input = _bwd_input_quantizer.quantize(bwd_input) - pytorch_fwd_input = fwd_input.from_float8(torch.float16) - 
pytorch_bwd_input = bwd_input.from_float8(torch.float16) + pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16) + pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16) else: pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() @@ -806,8 +821,8 @@ def _test_moe_chunk_sort( tols = dtype_tols(te_dtype) if fp8: - te_output_ = te_output.from_float8(torch.float32) - te_fwd_input_grad = te_fwd_input.grad.from_float8(torch.float32) + te_output_ = te_output.dequantize(dtype=torch.float32) + te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32) else: te_output_ = te_output.float() te_fwd_input_grad = te_fwd_input.grad.float() @@ -834,15 +849,6 @@ def _test_moe_chunk_sort( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. 
- if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs) @@ -873,6 +879,210 @@ def backward_wrapper( print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_mask_map_alongside_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "mask map alongside probs:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + _unpermute_bwd_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input) + unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input) + + pytorch_permute_fwd_input = 
permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + restore_shape = pytorch_permute_fwd_input.shape + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + if fp8: + probs = probs.to(torch.float16) + else: + probs = probs.to(dtype) + probs.requires_grad_(True) + + split_sizes = [0] * (num_expert * tp_size) + for i in range(num_out_tokens): + idx = random.randint(0, num_expert * tp_size - 1) + split_sizes[idx] += 1 + split_sizes = torch.tensor(split_sizes, dtype=torch.int32) + split_sizes_cuda = split_sizes.to(device="cuda") + + _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32) + sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel() + sorted_idxs_cuda = sorted_idxs.to(device="cuda") + + split_sizes_2 = [split_sizes[i] for i in sorted_idxs.tolist()] + split_sizes_2 = torch.tensor(split_sizes_2, dtype=torch.int32) + split_sizes_2_cuda = split_sizes_2.to(device="cuda") + + sorted_idxs_2 = [0] * (num_expert * tp_size) + for i in range(num_expert * tp_size): + sorted_idxs_2[sorted_idxs[i]] = i + sorted_idxs_2 = torch.tensor(sorted_idxs_2, dtype=torch.int32) + sorted_idxs_2_cuda = sorted_idxs_2.to(device="cuda") + + ################################################################################################################################### + # + # PyTorch Permutation + # + 
################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute_mask_map( + pytorch_permute_fwd_input, routing_map + ) + + pytorch_permute_output = pytorch_sort_chunks_by_index( + pytorch_permute_output, split_sizes, sorted_idxs + ) + + pytorch_permute_output = pytorch_sort_chunks_by_index( + pytorch_permute_output, split_sizes_2, sorted_idxs_2 + ) + + pytorch_unpermute_output = pytorch_unpermute_mask_map( + pytorch_permute_output, sorted_indices, restore_shape, probs, routing_map + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_probs = probs.detach() + te_probs.requires_grad_(True) + print(te_probs.shape) + + te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( + te_permute_fwd_input, + te_probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + print(te_permuted_probs.shape) + + te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( + te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda + ) + + if fp8: + _permute_output_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + te_permute_output = te_permute_output.dequantize(dtype=torch.float32) + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + 
te_permute_output = _permute_output_quantizer.quantize(te_permute_output) + else: + te_permute_output_dtype = te_permute_output.dtype + print(te_permute_output.shape) + print(te_permuted_probs.shape) + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) + + te_permute_output = te_sort_chunks_by_index( + te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda + ) + + te_unpermute_output = te_unpermute( + te_permute_output, + row_id_map, + restore_shape=restore_shape, + map_type="mask", + ) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + ############################################################################################### + + tols = dtype_tols(te_dtype) + + if fp8: + # backward of dequantize is in high precision + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + else: + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in fused_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in fused_permute bwd", + **tols, + ) + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols + ) + + def perf_test_cuda_kernel(cuda_kernel_fn): if torch.cuda.is_available(): # create CUDA event @@ -959,6 +1169,63 @@ def test_permutation_mask_map( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_permutation_mask_map_empty_input(te_dtype): + with_probs = True + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + hidden_size=4096, + topK=2, + num_out_tokens=0, + with_probs=with_probs, + 
BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +def test_permutation_mask_map_alongside_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + hidden_size=4096, + topK=2, + num_out_tokens=0, + tp_size=2, + ) + + # Only run FP8 tests on H100. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -1023,6 +1290,34 @@ def test_permutation_mask_map_fp8( ) +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +def test_permutation_mask_map_alongside_probs_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + ) + + @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [8, 16]) @@ -1101,6 +1396,20 @@ def test_chunk_permutation( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_chunk_permutation_empty_input(te_dtype): + BENCHMARK = False + + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + tp_size=2, + hidden_size=4096, + BENCHMARK=BENCHMARK, + ) + + def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) @@ -1149,6 +1458,16 @@ def test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=4, + ) + if __name__ == "__main__": test_permutation_single_case() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 57addca3b9..d424b97f74 100644 --- 
a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -76,8 +76,10 @@ def _load_library(): from transformer_engine.pytorch.transformer import TransformerLayer from transformer_engine.pytorch.permutation import ( moe_permute, + moe_permute_with_probs, moe_unpermute, moe_sort_chunks_by_index, + moe_sort_chunks_by_index_with_probs, ) from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 2e6167a6e0..dd2f60deba 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -261,13 +261,17 @@ def forward( inp: torch.Tensor, routing_map: torch.Tensor, num_out_tokens: int, + probs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): - return inp, torch.tensor([], device=inp.device) + ctx.probs = probs + return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) assert inp.is_cuda, "TransformerEngine needs CUDA." assert routing_map.is_cuda, "TransformerEngine needs CUDA." + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." 
assert inp.size(0) == routing_map.size(0), "Permute not possible" num_tokens, hidden_size = inp.size() @@ -282,48 +286,60 @@ def forward( if fp8: fp8_dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data - output = triton_permutation.permute_with_mask_map( + output, permuted_probs = triton_permutation.permute_with_mask_map( inp, row_id_map, + probs, num_tokens, num_experts, num_out_tokens, hidden_size, ) if fp8: - output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv) + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) ctx.save_for_backward(row_id_map) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size - return output, row_id_map + return output, row_id_map, permuted_probs @staticmethod def backward( ctx, permuted_act_grad: torch.Tensor, _, + permuted_probs_grad: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None + return permuted_act_grad, None, None, ctx.probs act_grad = None + probs_grad = None if ctx.needs_input_grad[0]: (row_id_map,) = ctx.saved_tensors fp8 = isinstance(permuted_act_grad, Float8Tensor) if fp8: fp8_dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data else: fp8_dtype = None - act_grad = triton_permutation.unpermute_with_mask_map( + act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, None, + permuted_probs_grad, ctx.num_tokens, ctx.num_experts, ctx.hidden_size, @@ -334,8 +350,12 @@ def backward( data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * ctx.num_experts, + shape=act_grad.shape, + dtype=fake_dtype, ) - return act_grad, None, None + if not ctx.needs_input_grad[3]: + 
probs_grad = None + return act_grad, None, None, probs_grad class _moe_unpermute_mask_map(torch.autograd.Function): @@ -346,12 +366,12 @@ def forward( ctx, inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor, + merging_probs: torch.Tensor, restore_shape: torch.Size, ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): - ctx.probs = probs + ctx.merging_probs = merging_probs return inp if restore_shape is None: @@ -359,15 +379,9 @@ def forward( num_tokens, hidden_size = restore_shape num_experts = row_id_map.size(0) - with_probs = probs is not None + with_probs = merging_probs is not None if with_probs: - assert probs.is_cuda, "TransformerEngine needs CUDA." - if probs.dtype != torch.float32: - warnings.warn( - f"The data type of the input `probs` of Unpermute is {probs.dtype}! " - "The recommended type is torch.float32." - ) - probs = probs.to(torch.float32) + assert merging_probs.is_cuda, "TransformerEngine needs CUDA." # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." 
@@ -380,13 +394,15 @@ def forward( fp8_scale_inv = inp._scale_inv * num_experts else: fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: fp8_dtype = None - unpermuted_output = triton_permutation.unpermute_with_mask_map( + unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, - probs, + merging_probs, + None, num_tokens, num_experts, hidden_size, @@ -394,11 +410,15 @@ def forward( ) if fp8: unpermuted_output = Float8Tensor( - data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=unpermuted_output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=unpermuted_output.shape, + dtype=fake_dtype, ) if with_probs: - ctx.save_for_backward(inp, row_id_map, probs) + ctx.save_for_backward(inp, row_id_map, merging_probs) else: ctx.save_for_backward(row_id_map) ctx.num_experts = num_experts @@ -412,13 +432,13 @@ def forward( def backward(ctx, unpermuted_act_grad): # pylint: disable=missing-function-docstring if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.probs, None + return unpermuted_act_grad, None, ctx.merging_probs, None act_grad = None probs_grad = None if ctx.needs_input_grad[0]: if ctx.with_probs: - fwd_input, row_id_map, probs = ctx.saved_tensors + fwd_input, row_id_map, merging_probs = ctx.saved_tensors else: (row_id_map,) = ctx.saved_tensors @@ -426,26 +446,30 @@ def backward(ctx, unpermuted_act_grad): if fp8: fp8_dtype = unpermuted_act_grad._fp8_dtype fp8_scale_inv = unpermuted_act_grad._scale_inv + fake_dtype = unpermuted_act_grad.dtype unpermuted_act_grad = unpermuted_act_grad._data else: fp8_dtype = None if ctx.with_probs: - act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs( - unpermuted_act_grad, - row_id_map, - fwd_input, - probs, - ctx.num_tokens, - ctx.num_experts, - ctx.num_permuted_tokens, - ctx.hidden_size, - fp8_dtype, + act_grad, probs_grad = ( + 
triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( + unpermuted_act_grad, + row_id_map, + fwd_input, + merging_probs, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + fp8_dtype, + ) ) else: - act_grad = triton_permutation.permute_with_mask_map( + act_grad, _ = triton_permutation.permute_with_mask_map( unpermuted_act_grad, row_id_map, + None, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, @@ -454,7 +478,11 @@ def backward(ctx, unpermuted_act_grad): if fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, ) if not ctx.needs_input_grad[2]: @@ -494,20 +522,56 @@ def moe_permute( map_type: str, default = 'mask' Type of the routing map tensor. Options are: 'mask', 'index'. + Refer to `routing_map` for more details. """ if map_type == "index": return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens) + output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None) + return output, row_id_map raise ValueError("map_type should be one of 'mask' or 'index'") +def moe_permute_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Permute the tokens and probs based on the routing_map. + Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. 
+ probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens, num_experts]. It will be permuted with the tokens + according to the routing_map. + routing_map: torch.Tensor + The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + """ + output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + inp, routing_map, num_out_tokens, probs + ) + return output, permuted_probs, row_id_map + + def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor = None, + merging_probs: torch.Tensor = None, restore_shape: torch.Tensor = None, map_type: str = "mask", + probs: torch.Tensor = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -520,7 +584,7 @@ def moe_unpermute( row_id_map: torch.Tensor The tensor of a mapping table for sorted indices used to unpermute the tokens, which is the second output tensor of `Permute`. - probs: torch.Tensor + merging_probs: torch.Tensor, default = None The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. @@ -529,11 +593,20 @@ def moe_unpermute( map_type: str, default = 'mask' Type of the routing map tensor. Should be the same as the value passed to moe_permute. Options are: 'mask', 'index'. + probs: torch.Tensor, default = None + Renamed to merging_probs. Keep for backward compatibility. 
""" + if probs is not None: + if merging_probs is not None: + raise ValueError( + "Both merging_probs and probs kwarg are provided. probs is deprecated." + ) + warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.") + merging_probs = probs if map_type == "index": - return _moe_unpermute_index_map.apply(inp, row_id_map, probs) + return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) if map_type == "mask": - return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape) + return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape) raise ValueError("map_type should be one of 'mask' or 'index'") @@ -546,14 +619,17 @@ def forward( inp: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, + probs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): - return inp, torch.tensor([], device=inp.device) + return inp, probs assert inp.is_cuda, "TransformerEngine needs CUDA." assert split_sizes.is_cuda, "TransformerEngine needs CUDA." assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." 
num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) @@ -563,51 +639,69 @@ def forward( if fp8: fp8_dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data - output, row_id_map = triton_permutation.sort_chunks_by_idx( + output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx( inp, split_sizes, sorted_idxs, + probs, num_tokens, hidden_size, num_splits, ) if fp8: - output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv) + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) ctx.save_for_backward(row_id_map) ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size - return output + return output, permuted_probs @staticmethod def backward( ctx, permuted_act_grad: torch.Tensor, + permuted_probs_grad: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None + return permuted_act_grad, None, None, permuted_probs_grad act_grad = None + probs_grad = None if ctx.needs_input_grad[0]: (row_id_map,) = ctx.saved_tensors fp8 = isinstance(permuted_act_grad, Float8Tensor) if fp8: fp8_dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data - act_grad = triton_permutation.sort_chunks_by_map( + act_grad, probs_grad = triton_permutation.sort_chunks_by_map( permuted_act_grad, row_id_map, + permuted_probs_grad, ctx.num_tokens, ctx.hidden_size, ) if fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, ) - return act_grad, None, None + if not ctx.needs_input_grad[3]: + probs_grad = None + return act_grad, None, None, probs_grad def 
moe_sort_chunks_by_index( @@ -629,4 +723,33 @@ def moe_sort_chunks_by_index( sorted_indices: torch.Tensor Chunk indices used to permute the chunks. """ - return _moe_chunk_sort.apply(inp, split_sizes, sorted_index) + output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None) + return output + + +def moe_sort_chunks_by_index_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + split_sizes: torch.Tensor, + sorted_index: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split and sort the input tensor and probs based on the split_sizes and sorted indices. + The inp tensor is splitted along dim-0 according to the split_sizes list and then sorted + according to the sorted_indices. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens]. It will be permuted with the tokens according to + the split_sizes and sorted_indices. + split_sizes: torch.Tensor + Chunk sizes of the inp tensor along the 0-th dimension. + sorted_indices: torch.Tensor + Chunk indices used to permute the chunks. 
+ """ + output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs) + return output, permuted_probs diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 767362e8c1..4ed92b0c80 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -125,6 +125,8 @@ def _permute_kernel( input_ptr, output_ptr, row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, # sizes num_tokens, num_experts, @@ -134,7 +136,11 @@ def _permute_kernel( stride_input_hidden, stride_output_token, stride_output_hidden, + stride_probs_token, + stride_probs_expert, + stride_permuted_probs_token, # metas + PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -149,12 +155,19 @@ def _permute_kernel( if dst_row != -1: output_off = dst_row * stride_output_token + cur_off * stride_output_hidden tl.store(output_ptr + output_off, inp, mask=mask) + if PERMUTE_PROBS: + if cur_pos == 0: + prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) cur_pos += BLOCK_SIZE def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, + probs: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -162,11 +175,17 @@ def permute_with_mask_map( ): # pylint: disable=missing-function-docstring output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None grid = (num_tokens,) _permute_kernel[grid]( inp, output, row_id_map, + probs, + permuted_probs, num_tokens, num_experts, hidden_size, @@ -174,8 +193,12 @@ def permute_with_mask_map( inp.stride(1), output.stride(0), output.stride(1), + 
probs.stride(0) if probs is not None else None, + probs.stride(1) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, ) - return output + return output, permuted_probs @triton.autotune( @@ -194,7 +217,9 @@ def _unpermute_kernel( input_ptr, output_ptr, row_id_map_ptr, - probs_ptr, + merging_probs_ptr, + permuted_probs_ptr, + unpermuted_probs_ptr, # sizes num_tokens, num_experts, @@ -204,24 +229,27 @@ def _unpermute_kernel( stride_input_hidden, stride_output_token, stride_output_hidden, - stride_probs_token, - stride_probs_expert, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_permuted_probs_token, + stride_unpermuted_probs_token, + stride_unpermuted_probs_expert, # metas - WITH_PROBS: tl.constexpr, + WITH_MERGING_PROBS: tl.constexpr, + PERMUTE_PROBS: tl.constexpr, FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): if FP8_DTYPE == "e5m2": - compute_type = tl.float16 data_type = tl.float8e5 pytorch_tensor_dtype = tl.uint8 elif FP8_DTYPE == "e4m3": - compute_type = tl.float16 data_type = tl.float8e4nv pytorch_tensor_dtype = tl.uint8 else: - compute_type = input_ptr.dtype.element_ty + data_type = input_ptr.dtype.element_ty assert FP8_DTYPE is None + compute_type = tl.float32 pid = tl.program_id(0) current_start = 0 @@ -235,18 +263,35 @@ def _unpermute_kernel( input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True).to(compute_type) - if WITH_PROBS: - prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + prob_off).to(compute_type) - inp *= prob + inp = inp.to(data_type, bitcast=True) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + 
merging_prob_off).to(compute_type) + inp *= merging_prob accumulator += inp + if PERMUTE_PROBS: + if current_start == 0: + unpermuted_prob_off = ( + pid * stride_unpermuted_probs_token + + expert_idx * stride_unpermuted_probs_expert + ) + if src_row != -1: + permuted_prob_off = src_row * stride_permuted_probs_token + prob = tl.load(permuted_probs_ptr + permuted_prob_off) + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) + else: + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) if FP8_DTYPE is not None: - if not WITH_PROBS: + if not WITH_MERGING_PROBS: # Directly adding these value may cause overflow for fp8, we scale it here. # The outside fp8_scale_inv is also scaled in the meantime. accumulator /= num_experts accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + else: + accumulator = accumulator.to(data_type) output_off = pid * stride_output_token + current_offset * stride_output_hidden tl.store(output_ptr + output_off, accumulator, mask=mask) current_start += BLOCK_SIZE @@ -255,7 +300,8 @@ def _unpermute_kernel( def unpermute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, - probs: Union[torch.Tensor, None], + merging_probs: Union[torch.Tensor, None], + permuted_probs: Union[torch.Tensor, None], num_tokens: int, num_experts: int, hidden_size: int, @@ -269,12 +315,20 @@ def unpermute_with_mask_map( else: fp8_dtype = None output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if permuted_probs is not None: + unpermuted_probs = torch.empty( + (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda" + ) + else: + unpermuted_probs = None grid = (num_tokens,) _unpermute_kernel[grid]( inp, output, row_id_map, - probs, + merging_probs, + permuted_probs, + unpermuted_probs, num_tokens, num_experts, hidden_size, @@ -282,12 +336,16 @@ def unpermute_with_mask_map( inp.stride(1), output.stride(0), output.stride(1), - probs.stride(0) if probs is not None else None, - 
probs.stride(1) if probs is not None else None, - WITH_PROBS=probs is not None, + merging_probs.stride(0) if merging_probs is not None else None, + merging_probs.stride(1) if merging_probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + unpermuted_probs.stride(0) if unpermuted_probs is not None else None, + unpermuted_probs.stride(1) if unpermuted_probs is not None else None, + WITH_MERGING_PROBS=merging_probs is not None, + PERMUTE_PROBS=permuted_probs is not None, FP8_DTYPE=fp8_dtype, ) - return output + return output, unpermuted_probs @triton.autotune( @@ -301,13 +359,13 @@ def unpermute_with_mask_map( key=["hidden_size"], ) @triton.jit -def _unpermute_bwd_with_probs_kernel( +def _unpermute_bwd_with_merging_probs_kernel( # pointers fwd_output_grad_ptr, fwd_input_grad_ptr, fwd_input_ptr, - probs_ptr, - probs_grad_ptr, + merging_probs_ptr, + merging_probs_grad_ptr, row_id_map_ptr, # sizes num_tokens, @@ -320,31 +378,30 @@ def _unpermute_bwd_with_probs_kernel( stride_fwd_input_grad_hidden, stride_fwd_input_token, stride_fwd_input_hidden, - stride_probs_token, - stride_probs_expert, - stride_probs_grad_token, - stride_probs_grad_expert, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_merging_probs_grad_token, + stride_merging_probs_grad_expert, # metas FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): if FP8_DTYPE == "e5m2": - compute_type = tl.float16 data_type = tl.float8e5 pytorch_tensor_dtype = tl.uint8 elif FP8_DTYPE == "e4m3": - compute_type = tl.float16 data_type = tl.float8e4nv pytorch_tensor_dtype = tl.uint8 else: - compute_type = fwd_output_grad_ptr.dtype.element_ty + data_type = fwd_output_grad_ptr.dtype.element_ty assert FP8_DTYPE is None + compute_type = tl.float32 pid = tl.program_id(0) for expert_idx in range(num_experts): dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) if dst_row != -1: - prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + 
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) current_start = 0 while current_start < hidden_size: current_offset = current_start + tl.arange(0, BLOCK_SIZE) @@ -355,12 +412,16 @@ def _unpermute_bwd_with_probs_kernel( ) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True).to(compute_type) - probs_off = pid * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + probs_off).to(compute_type) - output = inp * prob + inp = inp.to(data_type, bitcast=True) + inp = inp.to(compute_type) + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + output = inp * merging_prob + output = output.to(data_type) if FP8_DTYPE is not None: - output = output.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + output = output.to(pytorch_tensor_dtype, bitcast=True) output_off = ( dst_row * stride_fwd_input_grad_token + current_offset * stride_fwd_input_grad_hidden @@ -373,21 +434,27 @@ def _unpermute_bwd_with_probs_kernel( fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) if FP8_DTYPE is not None: fwd_input = fwd_input.to(data_type, bitcast=True) - prob_grad_accum += fwd_input.to(tl.float32) * inp.to(tl.float32) + prob_grad_accum += fwd_input.to(compute_type) * inp current_start += BLOCK_SIZE - probs_grad = tl.sum(prob_grad_accum) - probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert - tl.store(probs_grad_ptr + probs_grad_off, probs_grad) + probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) else: - probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert - 
tl.store(probs_grad_ptr + probs_grad_off, 0.0) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) -def unpermute_with_mask_map_bwd_with_probs( +def unpermute_with_mask_map_bwd_with_merging_probs( fwd_output_grad: torch.Tensor, row_id_map: torch.Tensor, fwd_input: torch.Tensor, - probs: torch.Tensor, + merging_probs: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -404,14 +471,16 @@ def unpermute_with_mask_map_bwd_with_probs( act_grad = torch.empty( (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) - probs_grad = torch.empty((num_tokens, num_experts), dtype=probs.dtype, device="cuda") + merging_probs_grad = torch.empty( + (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" + ) grid = (num_tokens,) - _unpermute_bwd_with_probs_kernel[grid]( + _unpermute_bwd_with_merging_probs_kernel[grid]( fwd_output_grad, act_grad, fwd_input, - probs, - probs_grad, + merging_probs, + merging_probs_grad, row_id_map, num_tokens, num_experts, @@ -422,13 +491,13 @@ def unpermute_with_mask_map_bwd_with_probs( act_grad.stride(1), fwd_input.stride(0), fwd_input.stride(1), - probs.stride(0), - probs.stride(1), - probs_grad.stride(0), - probs_grad.stride(1), + merging_probs.stride(0), + merging_probs.stride(1), + merging_probs_grad.stride(0), + merging_probs_grad.stride(1), fp8_dtype, ) - return act_grad, probs_grad + return act_grad, merging_probs_grad @triton.autotune( @@ -449,6 +518,8 @@ def _sort_chunks_by_idxs_kernel( sorted_indices_ptr, output_ptr, dst_rows_ptr, + probs_ptr, + permuted_probs_ptr, # sizes num_splits, hidden_size, @@ -457,7 +528,10 @@ def _sort_chunks_by_idxs_kernel( stride_input_hidden, stride_output_token, stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, # metas + PERMUTE_PROBS: tl.constexpr, IDX_LOAD_WIDTH: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ 
-508,11 +582,18 @@ def _sort_chunks_by_idxs_kernel( tl.store(output_ptr + output_offsets, inp, mask=mask) current_start += BLOCK_SIZE + if PERMUTE_PROBS: + prob_off = pid * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + def sort_chunks_by_idx( inp: torch.Tensor, split_sizes: torch.Tensor, sorted_indices: torch.Tensor, + probs: torch.Tensor, num_tokens: int, hidden_size: int, num_splits: int, @@ -520,6 +601,10 @@ def sort_chunks_by_idx( # pylint: disable=missing-function-docstring row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda") output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None grid = (num_tokens,) _sort_chunks_by_idxs_kernel[grid]( inp, @@ -527,15 +612,20 @@ def sort_chunks_by_idx( sorted_indices, output, row_id_map, + probs, + permuted_probs, num_splits, hidden_size, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), - triton.next_power_of_2(num_splits), + probs.stride(0) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, + IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), ) - return output, row_id_map + return output, row_id_map, permuted_probs @triton.autotune( @@ -554,6 +644,8 @@ def _sort_chunks_by_map( input_ptr, output_ptr, row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, # sizes hidden_size, # strides @@ -561,7 +653,10 @@ def _sort_chunks_by_map( stride_input_hidden, stride_output_token, stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, # metas + PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -575,25 +670,40 @@ def _sort_chunks_by_map( inp = tl.load(input_ptr + input_offsets, 
mask=mask) tl.store(output_ptr + output_offsets, inp, mask=mask) current_start += BLOCK_SIZE + if PERMUTE_PROBS: + prob_off = dst_row * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = pid * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) def sort_chunks_by_map( inp: torch.Tensor, row_id_map: torch.Tensor, + probs: torch.Tensor, num_tokens: int, hidden_size: int, ): # pylint: disable=missing-function-docstring output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None grid = (num_tokens,) _sort_chunks_by_map[grid]( inp, output, row_id_map, + probs, + permuted_probs, hidden_size, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), + probs.stride(0) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, ) - return output + return output, permuted_probs From 6673f1658c605f7999e04a43bde1493c4b9ad741 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 18 Feb 2025 15:48:59 -0800 Subject: [PATCH 093/239] [JAX] Flax with compute dtype inferred from input dtype. 
(#1485) flax module with compute dtype inferred from the inputs Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 9 +- examples/jax/encoder/test_multigpu_encoder.py | 11 +- .../encoder/test_multiprocessing_encoder.py | 5 +- .../jax/encoder/test_single_gpu_encoder.py | 12 +- examples/jax/mnist/test_single_gpu_mnist.py | 10 +- tests/jax/test_distributed_layernorm_mlp.py | 2 - tests/jax/test_layer.py | 8 +- tests/jax/utils.py | 54 ++++-- transformer_engine/jax/flax/module.py | 159 +++++++++--------- transformer_engine/jax/flax/transformer.py | 88 +++++----- 10 files changed, 178 insertions(+), 180 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index f02cc562b5..228105d553 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -56,7 +56,6 @@ def __call__(self, x, mask, disable_dropout=False): self_attn_mask_type="padding", enable_relative_embedding=False, enable_sequence_parallel=self.enable_seq_paral, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) @@ -72,17 +71,15 @@ def __call__(self, x, mask, disable_dropout=False): features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16, )(x) x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16, )(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, 
labels=one_hot)) return loss, logits @@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index eb4a1d0afb..0dab636718 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -51,17 +51,16 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = 
jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 91186a15c4..6522ed896a 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -57,7 +57,6 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) @@ -67,17 +66,15 @@ def __call__(self, x, mask, disable_dropout=False): features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16, )(x) x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16, )(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index dd1997fe6f..cfbd30b767 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -46,17 +46,16 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -66,7 +65,7 @@ def train_step(state, inputs, 
masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -217,6 +216,7 @@ def train_and_evaluate(args): with te.fp8_autocast(enabled=args.use_fp8): encoder = Net(num_embed) + # We use nn.Embed, thus inputs need to be in int inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) var_collect = encoder.init(init_rngs, inputs, masks) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 54ecadeee8..9d8f51cc16 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -36,6 +36,8 @@ def __call__(self, x, disable_dropout=False): nn_Dense = te_flax.DenseGeneral else: nn_Dense = nn.Dense + # dtype is used for param init in TE but computation in Linen.nn + dtype = jnp.float32 if self.use_te else jnp.bfloat16 x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x) x = nn.relu(x) @@ -44,11 +46,13 @@ def __call__(self, x, disable_dropout=False): x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = nn_Dense(features=128, dtype=jnp.bfloat16)(x) + assert x.dtype == jnp.bfloat16 + x = nn_Dense(features=128, dtype=dtype)(x) x = nn.relu(x) x = 
nn.Dropout(rate=0.5)(x, deterministic=disable_dropout) - x = nn_Dense(features=16, dtype=jnp.bfloat16)(x) - x = nn.Dense(features=10, dtype=jnp.bfloat16)(x) + x = nn_Dense(features=16, dtype=dtype)(x) + x = nn_Dense(features=10, dtype=dtype)(x) + assert x.dtype == jnp.bfloat16 return x diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 87a5145c65..77b299e5bf 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -271,7 +271,6 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, - dtype=dtype, use_bias=use_bias, ) params_single = ln_mlp_single.init(init_rngs, x) @@ -289,7 +288,6 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, - dtype=dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index a67335236d..ed15913f38 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -265,8 +265,8 @@ def test_forward( """Test only the forward""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) - ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) - layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + ref_layer_cls = partial(self.reference_layer, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs) ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) @@ -288,8 +288,8 @@ def test_backward( """Test forward and backward through value_and_grad()""" inputs, (ref_masks, test_masks) = 
self.generate_inputs(data_shape, dtype) - ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) - layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + ref_layer_cls = partial(self.reference_layer, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs) ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 554def2c3f..dba7cb64fc 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -110,7 +110,7 @@ class DotProductAttention(nn.Module): Args: dropout_rate: dropout rate - dtype: the dtype of the computation (default: float32) + dtype: the data type used to allocate the initial parameters (default: float32). float32_logits: bool, if True then compute logits in float32 to avoid numerical issues with bfloat16. """ @@ -195,6 +195,7 @@ def __call__( attn_weights = attn_weights * multiplier attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) + attn_weights = attn_weights.astype(value.dtype) # Take the linear combination of `value`. if self.transpose_batch_sequence: @@ -209,7 +210,7 @@ class DenseGeneral(nn.Module): Attributes: features: tuple with numbers of output features. axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). + dtype: the data type used to allocate the initial parameters (default: float32). kernel_init: initializer function for the weight matrix. use_bias: whether to add a bias to the output (default: False). bias_init: initializer function for the bias vector. 
@@ -226,7 +227,9 @@ class DenseGeneral(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -239,6 +242,7 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + input_dtype = inputs.dtype features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -248,23 +252,24 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes ) - kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.asarray(kernel, input_dtype) kernel = jnp.reshape(kernel, kernel_shape) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) else: bias = None contract_ind = tuple(range(0, len(axis))) y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + y = y.astype(input_dtype) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) @@ -281,7 +286,7 @@ class MlpBlock(nn.Module): kernel_init: Kernel function, passed to the dense layers. deterministic: Whether the dropout layers should be deterministic. intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. + dtype: the data type used to allocate the initial parameters (default: float32). 
""" transpose_batch_sequence: bool @@ -296,7 +301,9 @@ class MlpBlock(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -358,6 +365,9 @@ def __call__(self, inputs, deterministic: bool = False): bias_axes="embed", name="wo", )(x) + assert ( + output.dtype == inputs.dtype + ), f"input.dtype={input.dtype}, output.dtype={output.dtype}" return output @@ -429,7 +439,7 @@ class MultiHeadAttention(nn.Module): should be divisible by the number of heads. num_gqa_groups: number of kv attention heads head_dim: dimension of each head. - dtype: the dtype of the computation. + dtype: the data type used to allocate the initial parameters (default: float32). dropout_rate: dropout rate kernel_init: initializer for the kernel of the Dense layers. float32_logits: bool, if True then compute logits in float32 to avoid @@ -453,7 +463,9 @@ class MultiHeadAttention(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -738,6 +750,9 @@ def qkv_init(key, shape, dtype): dtype=self.dtype, name="out", )(x) + assert ( + inputs_q.dtype == inputs_kv.dtype == out.dtype + ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}" return out @@ -763,13 +778,13 @@ def __post_init__(self): def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) + input_dtype = x.dtype features = x.shape[-1] scale = nn_partitioning.param_with_axes( - "scale", self.scale_init, 
(features,), jnp.float32, axes=("embed",) + "scale", self.scale_init, (features,), self.dtype, axes=("embed",) ) - scale = jnp.asarray(scale, self.dtype) + scale = jnp.asarray(scale, input_dtype) if self.layernorm_type == "layernorm": mean = jnp.mean(x, axis=-1, keepdims=True) @@ -777,9 +792,9 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: y = (x - mean) * lax.rsqrt(var + self.epsilon) bias = nn_partitioning.param_with_axes( - "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",) + "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",) ) - bias = jnp.asarray(bias, self.dtype) + bias = jnp.asarray(bias, input_dtype) if not self.zero_centered_gamma: z = y * scale + bias @@ -792,7 +807,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: y = x * lax.rsqrt(mean2 + self.epsilon) z = y * scale - return jnp.asarray(z, self.dtype) + assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}" + return z class RelativePositionBiases(nn.Module): @@ -805,7 +821,7 @@ class RelativePositionBiases(nn.Module): distance bucket. num_heads: Number of heads in the attention layer. Each head will get a different relative position weighting. - dtype: Type of arrays through this module. + dtype: the data type used to allocate the initial parameters (default: float32). embedding_init: initializer for relative embedding table. 
""" @@ -1087,6 +1103,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): dtype=self.dtype, name="output_layernorm", )(y) + assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}" return y @@ -1293,6 +1310,7 @@ def __call__( name="output_layernorm", )(z) + assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}" return z diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 23bc8d3602..d814c2d4df 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -57,19 +57,15 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga def _create_layernorm_parameters( - layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype + layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, input_dtype, dtype ): - scale = nn_partitioning.param_with_axes( - "scale", scale_init, shape, weight_dtype, axes=scale_axes - ) - scale = scale.astype(dtype) + scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes) + scale = scale.astype(input_dtype) layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == "layernorm": - bias = nn_partitioning.param_with_axes( - "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes - ) - bias = bias.astype(dtype) + bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes) + bias = bias.astype(input_dtype) else: assert layernorm_type == "rmsnorm" bias = None @@ -158,15 +154,15 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp heads = inputs.shape[1] q_seqlen = inputs.shape[2] k_seqlen = inputs.shape[3] - dtype = inputs.dtype + input_dtype = inputs.dtype logits = inputs if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( - self.softmax_type, batch, heads, q_seqlen, k_seqlen, 
inputs.dtype + self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype ): if bias is not None: - logits = logits + bias.astype(dtype) + logits = logits + bias.astype(input_dtype) mask_ = mask if self.softmax_type is not SoftmaxType.SCALED_MASKED: @@ -178,25 +174,27 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp if mask is not None: attention_bias = lax.select( mask > 0, - jnp.full(mask.shape, -1e10).astype(dtype), - jnp.full(mask.shape, 0.0).astype(dtype), + jnp.full(mask.shape, -1e10), + jnp.full(mask.shape, 0.0), ) + attention_bias = attention_bias.astype(input_dtype) if bias is not None: attention_bias = _combine_biases(attention_bias, bias) if attention_bias is not None: - logits = logits + attention_bias.astype(dtype) + logits = logits + attention_bias.astype(input_dtype) # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED # and kernel is unavailable, then try on pure scaled softmax custom calls. if is_softmax_kernel_available( - SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype + SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype ): outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) else: outputs = jax_nn.softmax(logits * self.scale_factor) + assert input_dtype == outputs.dtype return outputs @@ -261,9 +259,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. 
If set to True, the input tensors @@ -278,7 +274,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ("embed",) dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): @@ -303,7 +298,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: outputs : jax.numpy.ndarray Output tensors. """ - x = x.astype(self.dtype) + input_dtype = x.dtype features = x.shape[-1] scale, ln_bias = _create_layernorm_parameters( @@ -313,10 +308,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: self.scale_axes, self.bias_init, self.bias_axes, + input_dtype, self.dtype, - self.weight_dtype, ) - return layernorm( + out = layernorm( x, scale, ln_bias, @@ -324,6 +319,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: zero_centered_gamma=self.zero_centered_gamma, epsilon=self.epsilon, ) + assert out.dtype == input_dtype + return out class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods @@ -408,9 +405,7 @@ class DenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. 
If set to True, the input tensors @@ -428,13 +423,12 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) super().__post_init__() @@ -454,24 +448,25 @@ def __call__(self, inputs: Array) -> Array: Output tensors. """ + input_dtype = inputs.dtype features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) - inputs = jnp.asarray(inputs, self.dtype) axis = _normalize_axes(axis, inputs.ndim) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - kernel = kernel.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel = kernel.astype(input_dtype) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) else: bias = None @@ -500,11 +495,11 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - self.weight_dtype, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) - lora_a_kernel = lora_a_kernel.astype(self.dtype) + lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) 
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) @@ -512,10 +507,10 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=lora_b_kernel_axes, ) - lora_b_kernel = lora_b_kernel.astype(self.dtype) + lora_b_kernel = lora_b_kernel.astype(input_dtype) y += _apply_low_rank_adaptation( inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha @@ -524,6 +519,8 @@ def __call__(self, inputs: Array) -> Array: if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape y += jnp.reshape(bias, bias_shape) + + assert y.dtype == input_dtype return y @@ -606,9 +603,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -638,7 +633,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None @@ -650,7 +644,7 @@ def __post_init__(self): 1.0, "fan_in", "truncated_normal", - dtype=self.weight_dtype, + dtype=self.dtype, ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init, @@ -677,6 +671,7 @@ def __call__(self, inputs: Array) -> Array: If :attr:`return_layernorm_output=False`, then this would be None. 
""" + input_dtype = inputs.dtype ln_output = None fuse_layernorm = ( @@ -684,7 +679,6 @@ def __call__(self, inputs: Array) -> Array: and not self.return_layernorm_output and self.enable_layernorm ) - inputs = inputs.astype(self.dtype) if self.enable_layernorm: inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) @@ -699,8 +693,8 @@ def __call__(self, inputs: Array) -> Array: self.scale_axes, self.ln_bias_init, self.ln_bias_axes, + input_dtype, self.dtype, - self.weight_dtype, ) if not fuse_layernorm: @@ -730,9 +724,10 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - kernel = kernel.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -775,11 +770,11 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - self.weight_dtype, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) - lora_a_kernel = lora_a_kernel.astype(self.dtype) + lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) @@ -787,10 +782,10 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=lora_b_kernel_axes, ) - lora_b_kernel = lora_b_kernel.astype(self.dtype) + lora_b_kernel = lora_b_kernel.astype(input_dtype) z += _apply_low_rank_adaptation( y, axis, features, lora_a_kernel, lora_b_kernel, 
self.low_rank_adaptation_alpha @@ -799,9 +794,9 @@ def __call__(self, inputs: Array) -> Array: bias = None if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) if bias is not None: bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape @@ -810,6 +805,7 @@ def __call__(self, inputs: Array) -> Array: if self.depth_scaling is not None: z = z / self.depth_scaling + assert z.dtype == input_dtype return z, ln_output # dense_output, layer_norm_output @@ -915,9 +911,7 @@ class LayerNormMLP(TransformerEngineBase): Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -950,7 +944,6 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] 
= None @@ -959,7 +952,7 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init, @@ -988,6 +981,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: If :attr:`return_layernorm_output=False`, then this would be None. """ + input_dtype = inputs.dtype ln_output = None fuse_layernorm = ( @@ -996,8 +990,6 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: and self.enable_layernorm ) - inputs = inputs.astype(self.dtype) - gated_act_pool = [ ("gelu", "linear"), ("silu", "linear"), @@ -1035,8 +1027,8 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: self.scale_axes, self.ln_bias_init, self.ln_bias_axes, + input_dtype, self.dtype, - self.weight_dtype, ) if not fuse_layernorm: @@ -1083,11 +1075,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, kernel_1_each_shape, - self.weight_dtype, + self.dtype, axes=self.kernel_axes_1, ) kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) - kernel_1 = kernel_1.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1096,11 +1089,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_kernel", self.kernel_init, kernel_2_param_shape, - self.weight_dtype, + self.dtype, axes=self.kernel_axes_2, ) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) - kernel_2 = kernel_2.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) ffn1_ckpt_name = "ffn1" @@ -1115,20 
+1109,20 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_bias", self.bias_init, bias_1_shape, - self.weight_dtype, + self.dtype, axes=self.bias_axes_1, ) - bias_1 = bias_1.astype(self.dtype) + bias_1 = bias_1.astype(input_dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( "wo_bias", self.bias_init, bias_2_shape, - self.weight_dtype, + self.dtype, axes=self.bias_axes_2, ) - bias_2 = bias_2.astype(self.dtype) + bias_2 = bias_2.astype(input_dtype) else: bias_1 = None bias_2 = None @@ -1195,11 +1189,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, wi_lora_a_kernel_init_each_shape, - self.weight_dtype, + self.dtype, axes=wi_lora_a_kernel_axes, ) wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) - wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype) + wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_b_kernel_shape = ( num_activations, @@ -1211,10 +1205,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_lora_b_kernel", nn.initializers.zeros, wi_lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=wi_lora_b_kernel_axes, ) - wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) + wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype) x += _apply_low_rank_adaptation( y, @@ -1231,11 +1225,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_bias", self.bias_init, intermediate_dim, - self.weight_dtype, + self.dtype, axes=self.bias_axes_1, ) bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape - bias_1 = bias_1.astype(self.dtype) + bias_1 = bias_1.astype(input_dtype) x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) @@ -1250,7 +1244,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): z = functools.reduce(operator.mul, activations) # Remove act axis z = jnp.reshape(z, (*z.shape[:-2], -1)) - z = z.astype(self.dtype) + z = z.astype(input_dtype) z = nn.Dropout( 
rate=self.intermediate_dropout_rate, @@ -1259,7 +1253,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): )(z, deterministic=deterministic) z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes) - z = z.astype(self.dtype) + z = z.astype(input_dtype) # DenseGeneral 2 out = type_safe_dot_general( @@ -1273,10 +1267,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_a_kernel", self.kernel_init, wo_lora_a_kernel_shape, - self.weight_dtype, + self.dtype, axes=wo_lora_a_kernel_axes, ) - wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) + wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype) wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size) wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape) @@ -1284,10 +1278,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_b_kernel", nn.initializers.zeros, wo_lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=wo_lora_b_kernel_axes, ) - wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) + wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype) out += _apply_low_rank_adaptation( z, @@ -1304,12 +1298,13 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_bias", self.bias_init, (hidden_size,), - self.weight_dtype, + self.dtype, axes=self.bias_axes_2, ) - bias_2 = bias_2.astype(self.dtype) + bias_2 = bias_2.astype(input_dtype) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out = checkpoint_name(out, ffn2_ckpt_name) + assert out.dtype == input_dtype return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 100557404b..69fb74ba31 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -115,7 +115,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK 
attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 float32_logits: bool = False scale_factor: Optional[float] = None transpose_batch_sequence: bool = True @@ -143,6 +142,8 @@ def __call__( assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match." assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." + input_dtype = query.dtype + if self.scale_factor is None: scale_factor = 1.0 / sqrt(query.shape[-1]) else: @@ -150,8 +151,8 @@ def __call__( del self.scale_factor if self.float32_logits: - query = query.astype(self.dtype) - key = key.astype(self.dtype) + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) h_q, h_kv = query.shape[-2], key.shape[-2] # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # Therefore, we have to maintain two code paths. @@ -234,7 +235,7 @@ def convert_to_softmax_type(attn_mask_type, mask): attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)( attn_weights, mask, bias - ).astype(self.dtype) + ).astype(input_dtype) if is_gqa: attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) @@ -244,9 +245,12 @@ def convert_to_softmax_type(attn_mask_type, mask): dropout_shape = list(attn_weights.shape) # TODO(rewang): add attention dropout broadcast dimension arguments for users keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) + multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype) attn_weights = attn_weights * multiplier + assert ( + attn_weights.dtype == input_dtype + ), f"output={attn_weights.dtype}, input={input_dtype}" if self.transpose_batch_sequence: if is_gqa: return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) @@ -254,6 +258,7 @@ def convert_to_softmax_type(attn_mask_type, mask): if is_gqa: return 
jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) + return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) @@ -262,7 +267,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD scale_factor: Optional[float] = None transpose_batch_sequence: bool = False @@ -372,6 +376,7 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) + assert x.dtype == query.dtype return x @@ -492,9 +497,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. """ head_dim: int @@ -504,7 +507,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods attn_mask_type: AttnMaskType = "causal" attn_bias_type: AttnBiasType = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 dropout_rng_name: str = "dropout" float32_logits: bool = False qkv_layout: str = "bshd_bshd_bshd" @@ -552,6 +554,7 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. 
""" + input_dtype = query.dtype if mask is not None: if sequence_descriptor is not None: @@ -642,7 +645,6 @@ def __call__( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, dtype=self.dtype, - weight_dtype=self.weight_dtype, float32_logits=self.float32_logits, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, @@ -662,7 +664,6 @@ def __call__( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, dtype=self.dtype, - weight_dtype=self.weight_dtype, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, @@ -679,7 +680,7 @@ def __call__( dropout_rng=dropout_rng, deterministic=deterministic, ) - + assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}" return x @@ -720,10 +721,10 @@ def alternate_impl(): sin, cos = generate_sin_cos(time_scales) x1, x2 = jnp.split(x, 2, axis=-1) - part_1 = (x1 * cos - x2 * sin).astype(x.dtype) - part_2 = (x2 * cos + x1 * sin).astype(x.dtype) + part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype) + part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype) - output = jnp.concatenate([part_1, part_2], axis=-1) + output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype) return output def consecutive_impl(): @@ -928,8 +929,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. 
fuse_qkv_params: bool, default = True If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for @@ -975,7 +974,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 fuse_qkv_params: bool = True transpose_batch_sequence: bool = True enable_sequence_parallel: bool = False @@ -1026,7 +1024,7 @@ def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", dtype=self.weight_dtype + 1.0, "fan_in", "normal", dtype=self.dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1071,6 +1069,11 @@ def __call__( Output tensors. """ + assert ( + inputs_q.dtype == inputs_kv.dtype + ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}" + input_dtype = inputs_q.dtype + def query_init(*args): depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) @@ -1154,7 +1157,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): dot_input_axes=inputs_logical_axes_no_sp, name="qkv", dtype=self.dtype, - weight_dtype=self.weight_dtype, )(inputs_q) qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj") qkv_layout = QKVLayout.BS3HD @@ -1178,7 +1180,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, @@ -1203,7 +1204,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, name="kv", dtype=self.dtype, - 
weight_dtype=self.weight_dtype, )(inputs_kv) kv_proj = checkpoint_name(kv_proj, "combined_kv_proj") qkv_layout = QKVLayout.BSHD_BS2HD @@ -1221,7 +1221,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, ) query, ln_out = LayerNormDenseGeneral( enable_layernorm=self.input_layernorm, @@ -1242,7 +1241,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, @@ -1253,9 +1251,11 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): assert ln_out is not None inputs_kv = ln_out + query = query.astype(input_dtype) key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) - key = key.astype(self.dtype) + key = key.astype(input_dtype) value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) + value = value.astype(input_dtype) query = checkpoint_name(query, "query_proj") key = checkpoint_name(key, "key_proj") value = checkpoint_name(value, "value_proj") @@ -1380,7 +1380,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): attn_bias_type=self.attn_bias_type, attention_dropout=self.attention_dropout, dtype=self.dtype, - weight_dtype=self.weight_dtype, dropout_rng_name=self.dropout_rng_name, float32_logits=self.float32_logits, qkv_layout=qkv_layout.name, @@ -1406,11 +1405,13 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, name="out", )(x) out = checkpoint_name(out, "out_proj") + assert ( + inputs_q.dtype == 
out.dtype + ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}" return out, ln_out @@ -1435,9 +1436,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. """ num_buckets: int @@ -1446,7 +1445,6 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho embedding_init: Callable[..., Array] = nn.linear.default_embed_init embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets") dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 @nn.compact def __call__(self, q_seqlen, k_seqlen, bidirectional=True): @@ -1499,7 +1497,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True): "rel_embedding", self.embedding_init, (self.num_attention_heads, self.num_buckets), - self.weight_dtype, + self.dtype, axes=self.embedding_axes, ) @@ -1672,9 +1670,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. drop_path: float, default = 0.0 When > 0.0, applies stochastic depth per sample in the main path of the residual block. 
@@ -1727,7 +1723,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 drop_path: float = 0.0 fuse_qkv_params: bool = True transpose_batch_sequence: bool = False @@ -1739,11 +1734,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): if self.mha_kernel_init is None: self.mha_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", dtype=self.weight_dtype + 1.0, "fan_in", "normal", dtype=self.dtype ) if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1793,9 +1788,7 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. """ - - inputs = inputs.astype(self.dtype) - + input_dtype = inputs.dtype assert ( self.layer_type in TransformerLayerType ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." 
@@ -1833,8 +1826,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, - weight_dtype=self.weight_dtype, - embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), + embedding_init=nn.initializers.variance_scaling( + 1.0, "fan_avg", "uniform", dtype=self.dtype + ), name="relpos_bias", ) else: @@ -1867,7 +1861,6 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): x, ln_out = MultiHeadAttention( num_attention_heads=self.num_attention_heads, dtype=self.dtype, - weight_dtype=self.weight_dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, @@ -1946,7 +1939,6 @@ def hidden_dropout(x, deterministic): y, ln_out = MultiHeadAttention( num_attention_heads=self.num_attention_heads, dtype=self.dtype, - weight_dtype=self.weight_dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, @@ -2012,7 +2004,6 @@ def hidden_dropout(x, deterministic): intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, dtype=self.dtype, - weight_dtype=self.weight_dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_init=self.mlp_kernel_init, @@ -2062,8 +2053,7 @@ def hidden_dropout(x, deterministic): bias_axes=(W_NO_SHARD_AXES,), transpose_batch_sequence=self.transpose_batch_sequence, dtype=self.dtype, - weight_dtype=self.weight_dtype, name="output_layernorm", )(z) - + assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" return z From 978f1d72963f161654188b9ec3658e99d1e22dba Mon Sep 17 00:00:00 2001 From: Zhenhuan Liu Date: Wed, 19 Feb 2025 10:49:53 +0800 Subject: [PATCH 094/239] Fix issues for MCore DDP. (#1474) * Fix issues for MCore DDP. Signed-off-by: Dennis Liu * Remove force data release for CPU offloading. 
Signed-off-by: Dennis Liu * Add preserved attributes. Signed-off-by: Dennis Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add main_grad to preserved attributes. Signed-off-by: Dennis Liu * Change prepare_for_saving to original tensor and add .data to CPU hook. Signed-off-by: Dennis Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update. Signed-off-by: Dennis Liu * Fix for LayernormLinear in FP8. Signed-off-by: Dennis Liu --------- Signed-off-by: Dennis Liu Co-authored-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/cpu_offload.py | 4 +++- .../pytorch/module/layernorm_linear.py | 19 ++++++++++------ transformer_engine/pytorch/module/linear.py | 7 +++++- .../pytorch/tensor/quantized_tensor.py | 22 +++++-------------- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 33de562a89..c47130fe78 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -137,7 +137,9 @@ def __init__( super().__init__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + retrieve_identifier = self.offload_handler.tensor_push( + tensor.data, **self.handler_extra_kwargs + ) return retrieve_identifier def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d7a7f20dc4..01bda64101 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -441,7 +441,7 @@ def backward( ( # pylint: disable=unbalanced-tuple-unpacking inputmat,
weight, - _, + origin_weight, bias, ln_weight, ln_out, @@ -722,17 +722,22 @@ def backward( if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): - weight.grad_added_to_main_grad = True - if getattr(weight, "zero_out_wgrad", False): + if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + origin_weight.grad_added_to_main_grad = True + if getattr(origin_weight, "zero_out_wgrad", False): wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, + origin_weight.main_grad.shape, + dtype=origin_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: - wgrad = None + wgrad = torch.empty( + origin_weight.main_grad.shape, + dtype=origin_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: wgrad = None else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 415cc7d9a9..e51513630f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -606,7 +606,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], requires_grad=False, ) else: - wgrad = None + wgrad = torch.empty( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) elif ctx.fuse_wgrad_accumulation: wgrad = None else: diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 707382696d..ef21412ca7 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -28,7 +28,7 @@ def prepare_for_saving( tensor_list.append(None) tensor_objects_list.append(None) elif type(tensor) in (torch.Tensor, torch.nn.Parameter): - tensor_list.append(tensor.data) + tensor_list.append(tensor) 
tensor_objects_list.append(None) else: t, t_obj = tensor.prepare_for_saving() @@ -116,10 +116,7 @@ def update_quantized( """Quantize tensor in-place""" def quantize( - self, - tensor: torch.Tensor, - *, - out: Optional[QuantizedTensor] = None, + self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None ) -> QuantizedTensor: """Quantize tensor""" if out is not None: @@ -159,10 +156,7 @@ def calibrate(self, tensor: torch.Tensor) -> None: """ def set_usage( - self, - *, - rowwise: Optional[bool] = None, - columnwise: Optional[bool] = None, + self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None ) -> None: """Set how the quantized tensor is expected to be used @@ -194,8 +188,7 @@ def forward( @staticmethod def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, + _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision @@ -212,9 +205,7 @@ class _IdentityFunc(torch.autograd.Function): @staticmethod def forward( - ctx, - tensor: QuantizedTensor, - init_kwargs: Optional[Dict[str, Any]] = None, + ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None ) -> QuantizedTensor: # pylint: disable=missing-function-docstring @@ -408,8 +399,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return torch._C._disabled_torch_function_impl(func, types, args, kwargs) def contiguous( - self, - memory_format: torch.memory_format = torch.contiguous_format, + self, memory_format: torch.memory_format = torch.contiguous_format ) -> QuantizedTensor: # pylint: disable=missing-function-docstring raise NotImplementedError( From 11f15fdffa98666e7b48588ae444852e61a06291 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Tue, 18 Feb 2025 18:56:23 -0800 Subject: [PATCH 095/239] WIP: flash_attn_with_kvcache Signed-off-by: Charlene Yang --- 
tests/pytorch/fused_attn/test_paged_attn.py | 40 ++- transformer_engine/pytorch/attention.py | 309 +++++++++--------- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cu | 21 +- .../pytorch/kv_cache_manager_paged.py | 8 + 5 files changed, 201 insertions(+), 179 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 75e23f89df..35fa8984dc 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -77,7 +77,7 @@ def __init__( self.context_lens = torch.randint( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) - #self.context_lens = 10 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + #self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -88,7 +88,7 @@ def __init__( self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( dtype=torch.int32, device="cpu" ) - #self.gen_lens = 5 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + #self.gen_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate arrival times in Poisson distribution if poisson_rate is None: @@ -198,7 +198,7 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FlashAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() @@ -285,7 +285,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, 
is_cuda_graph): max_batch_size = config.batch_size page_size = None total_num_pages = None - if is_paged: + if is_paged: page_size = 256 if backend == "FlashAttention" else 16 config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) @@ -385,9 +385,6 @@ def gen_data(): # increase counter for gen_lens = [3, 5, 1, 1] max_tokens = config.batch_size * config.max_ctx_len while True: - if inference_params.is_paged: - inference_params.cache_manager.print_cache() - # prepare batch for the current step dynamic_fill = True #inference_params.is_paged sim.step(dynamic_fill=dynamic_fill) @@ -461,7 +458,6 @@ def gen_data(): for i, seq in enumerate(sim.t_seq_ids): start = (sim.t_total_lens[i] - sim.step_lens[i]).item() end = sim.t_total_lens[i].item() - print('i, seq', i, seq, start, end, sim.step_lens[i], incremental_q.shape, q.shape) incremental_q[i, : sim.step_lens[i], :, :] = q[seq, start:end, :, :] incremental_k[i, : sim.step_lens[i], :, :] = k[seq, start:end, :, :] incremental_v[i, : sim.step_lens[i], :, :] = v[seq, start:end, :, :] @@ -480,6 +476,8 @@ def gen_data(): zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) ) inference_params.pre_step(step_dict) + if inference_params.is_paged: + inference_params.cache_manager.print_cache() line_output = model( incremental_q, incremental_k, @@ -492,7 +490,6 @@ def gen_data(): max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, ) - print("lllllllll ", line_output.shape) # compare results if backend != "FlashAttention": @@ -508,30 +505,31 @@ def gen_data(): torch.bfloat16: 1e-2, } for i, seq in enumerate(sim.t_seq_ids): + token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 if qkv_format == "bshd": - print('seqq ', i, seq, sim.t_total_lens[i], sim.step_lens[i]) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(line_output[i, :, :4]) - #print(line_output[i, sim.step_lens[i] - 1, :]) 
torch.testing.assert_close( - full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - line_output[i, :sim.step_lens[i] - 1, :], - #full_output[seq, sim.t_total_lens[i] - 1, :], - #line_output[i, sim.step_lens[i] - 1, :], + #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + #line_output[:sim.step_lens[i] - 1, i, :], + full_output[seq, sim.t_total_lens[i] - 1, :], + line_output[i, token_index, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "sbhd": torch.testing.assert_close( - full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - line_output[:sim.step_lens[i] - 1, i, :], + #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + #line_output[:sim.step_lens[i] - 1, i, :], + full_output[seq, sim.t_total_lens[i] - 1, :], + line_output[token_index, i, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "thd": torch.testing.assert_close( - full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], + #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + #line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], + full_output[seq, sim.t_total_lens[i] - 1, :], + line_output[cu_seqlens_q[i + 1] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f2db716853..944930738e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -119,12 +119,14 @@ def _get_supported_versions(version_min, version_max): _flash_attn_max_version = PkgVersion("2.7.3") _flash_attn_2_plus = False _flash_attn_2_1_plus = False +_flash_attn_2_2_plus = False _flash_attn_2_3_plus = False _flash_attn_2_4_plus = False _flash_attn_2_4_1_plus = False +_flash_attn_2_5_plus = False _flash_attn_2_5_7_plus = False 
-_flash_attn_2_6_0_plus = False -_flash_attn_2_7_0_plus = False +_flash_attn_2_6_plus = False +_flash_attn_2_7_plus = False flash_attn_cuda_bwd = None flash_attn_func = None @@ -163,12 +165,16 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") + _flash_attn_2_2_plus = _flash_attn_version >= PkgVersion("2.2") + if _flash_attn_2_2_plus: + from flash_attn.flash_attn_interface import flash_attn_with_kvcache _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") + _flash_attn_2_5_plus = _flash_attn_version >= PkgVersion("2.5.0") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") - _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") - _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0") + _flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6.0") + _flash_attn_2_7_plus = _flash_attn_version >= PkgVersion("2.7.0") elif ( torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN ): @@ -212,6 +218,9 @@ def _get_supported_versions(version_min, version_max): from flashattn_hopper.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) + from flashattn_hopper.flash_attn_interface import ( + flash_attn_with_kvcache as flash_attn_with_kvcache_v3, + ) from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 from flashattn_hopper.flash_attn_interface import ( @@ -505,7 +514,7 @@ def get_attention_backend( # backend | non-paged/paged | precision # --------------------------------------------------------------------------------- # FlashAttention | non-paged/paged | 
FP16/BF16 - # FusedAttention | non-paged/paged | FP16/BF16 + # FusedAttention | non-paged/paged | FP16/BF16 (non-paged/paged), FP8 (non-paged) # UnfusedDotProductAttention | non-paged/paged | FP32/FP16/BF16 if inference_params is not None: if context_parallel: @@ -516,17 +525,26 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: - logger.debug("Disabling all backends as FP8 KV caching is not yet implemented") - use_flash_attention = False - use_fused_attention = False - use_unfused_attention = False + if use_flash_attention: + use_flash_attention = False + logger.debug("Disabling FlashAttention for FP8 KV caching") + if use_fused_attention and inference_params.is_paged: + use_fused_attention = False + logger.debug("Disabling FusedAttention as it does not support paged attention in FP8") + if use_unfused_attention: + use_unfused_attention = False + logger.debug("Disabling UnfusedAttention as it does not support FP8 attention") + else: + if use_flash_attention and not _flash_attn_2_2_plus and not _use_flash_attn_3: + use_flash_attention = False + logger.debug("Disabling FlashAttention as KV caching requires flash-attn 2.2+, or 3.0 (Hopper only)") if inference_params.is_paged: if use_fused_attention and cudnn_version < (9, 5, 0): - logger.debug("Disabling FusedAttention as paged KV caching requires cuDNN 9.5+") + logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+") use_fused_attention = False - if use_flash_attention and not _use_flash_attn_3 and not _flash_attn_2_5_7_plus: + if use_flash_attention and not _use_flash_attn_3 and not _flash_attn_2_5_plus: logger.debug( - "Disabling FlashAttention as paged KV caching requires flash-attn 2.5.7+ or v3" + "Disabling FlashAttention as paged attention requires flash-attn 2.5+, or 3.0 (Hopper only)" ) use_flash_attention = False @@ -2014,7 +2032,7 @@ def forward( if use_fused_attention: softmax_lse_in_packed_format = 
get_cudnn_version() >= (9, 6, 0) else: - softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3 + softmax_lse_in_packed_format = _flash_attn_2_6_plus or _use_flash_attn_3 flash_attn_fwd = None if not use_fused_attention: @@ -2032,16 +2050,16 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3: + if (_flash_attn_2_3_plus and not _flash_attn_2_7_plus) or _use_flash_attn_3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = 0 if causal else -1 if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_0_plus: + if _flash_attn_2_6_plus: fa_forward_kwargs["softcap"] = 0.0 # Flash Attn inputs @@ -2206,7 +2224,7 @@ def forward( causal=True, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not _flash_attn_2_7_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2317,10 +2335,10 @@ def forward( max_seqlen_kv // 2, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_plus ): fa_forward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( @@ -2339,7 +2357,7 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not _flash_attn_2_7_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2459,10 +2477,10 @@ def forward( max_seqlen_kv, ] if 
_use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_plus ): fa_forward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( @@ -2481,7 +2499,7 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not _flash_attn_2_7_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2592,7 +2610,7 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not _flash_attn_2_7_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2951,7 +2969,7 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_0_plus: + if _flash_attn_2_6_plus: fa_backward_kwargs["softcap"] = 0.0 for i in range(cp_size): @@ -3084,10 +3102,10 @@ def backward(ctx, dout): ctx.max_seqlen_kv, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_plus ): fa_backward_kwargs["window_size"] = (-1, 0) - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 if not _use_flash_attn_3: @@ -3199,10 +3217,10 @@ def backward(ctx, dout): ctx.max_seqlen_kv // 2, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_plus ): fa_backward_kwargs["window_size"] = (-1, -1) - if _flash_attn_2_7_0_plus: + if _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: @@ -3317,10 +3335,10 @@ def backward(ctx, 
dout): ctx.max_seqlen_kv, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_plus ): fa_backward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: @@ -3410,9 +3428,9 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_backward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: @@ -3755,7 +3773,7 @@ def forward( fa_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_0_plus: + if _flash_attn_2_6_plus: fa_forward_kwargs["softcap"] = 0.0 assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
@@ -3866,10 +3884,10 @@ def forward( max_seqlen_kv_, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_plus ): fa_forward_kwargs["window_size"] = window_size_per_step[i] - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( @@ -3880,7 +3898,7 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not _flash_attn_2_7_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -4002,7 +4020,7 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_0_plus: + if _flash_attn_2_6_plus: fa_backward_kwargs["softcap"] = 0.0 for i in range(len(local_seq_chunk_ids) + 1): @@ -4061,9 +4079,9 @@ def backward(ctx, dout): ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] - if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: + if _flash_attn_2_3_plus and not _flash_attn_2_7_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] - if _flash_attn_2_7_0_plus: + if _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( @@ -4214,16 +4232,16 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_forward_kwargs["window_size"] = window_size - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_forward_kwargs["window_size_left"] = 
window_size[0] fa_forward_kwargs["window_size_right"] = window_size[1] if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_0_plus: + if _flash_attn_2_6_plus: fa_forward_kwargs["softcap"] = 0.0 assert ( @@ -4336,7 +4354,7 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not _flash_attn_2_7_plus: out, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not _use_flash_attn_3 else None else: @@ -4512,16 +4530,16 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_backward_kwargs["window_size"] = ctx.window_size - elif _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = ctx.window_size[0] fa_backward_kwargs["window_size_right"] = ctx.window_size[1] if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_0_plus: + if _flash_attn_2_6_plus: fa_backward_kwargs["softcap"] = 0.0 if ctx.use_fused_attention: @@ -5692,9 +5710,7 @@ def forward( cu_seqlens_q = cu_seqlens_q[:batch_size+1] cu_seqlens_kv = cu_seqlens_kv[:batch_size+1] - if inference_params is None or ( - inference_params is not None and not inference_params.is_paged - ): + if inference_params is None: # [b * s, h, d] query_layer, key_layer, value_layer = [ x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) @@ -5732,21 +5748,6 @@ def forward( key_layer, value_layer = PackTensors.apply( indices_kv, key_layer, value_layer ) - else: - # [b * s, h, d] - query_layer = query_layer.reshape( - query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:] - ) - 
if cu_seqlens_q is None: - assert ( - attention_mask is not None - ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = get_cu_seqlens_and_indices( - attention_mask if self.attention_type == "self" else attention_mask[0] - ) - else: - indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - query_layer = PackTensors.apply(indices_q, query_layer) else: # Cumulative sequence lengths for unpadded data if cu_seqlens_q is None: @@ -5821,101 +5822,112 @@ def forward( if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] - if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: - func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 - else: - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None - if inference_params is not None and inference_params.is_paged: - fa_optional_forward_kwargs["block_table"] = inference_params.cache_manager.page_table[:batch_size] - func = ( - flash_attn_varlen_func - if not _use_flash_attn_3 - else flash_attn_varlen_func_v3 + if inference_params is not None: + func = flash_attn_with_kvcache + fa_optional_forward_kwargs_kvcache = {} + fa_optional_forward_kwargs_kvcache["cache_seqlens"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_optional_forward_kwargs_kvcache["softmax_scale"] = self.softmax_scale + fa_optional_forward_kwargs_kvcache["causal"] = "causal" in attn_mask_type + if inference_params.is_paged: + fa_optional_forward_kwargs_kvcache["block_table"] = inference_params.cache_manager.page_table[:batch_size] + output = func( + query_layer, + key_layer, + value_layer, + **fa_optional_forward_kwargs_kvcache, ) - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) - fa_optional_forward_args_thd.append(max_seqlen_q) - fa_optional_forward_args_thd.append(max_seqlen_kv) - if _use_flash_attn_3: - fa_3_optional_forward_kwargs = {} - 
fa_3_optional_forward_kwargs["window_size"] = window_size - fa_3_optional_forward_kwargs["deterministic"] = self.deterministic - if fp8: - QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) - torch_orig_dtype = query_layer.dtype - - def convert_to_torch_float8(tensor, dtype): - out = torch.Tensor().to(device=tensor.device, dtype=dtype) - out.set_( - tensor._data.untyped_storage(), - tensor._data.storage_offset(), - tensor._data.shape, - tensor._data.stride(), + else: + if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: + func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 + else: + func = ( + flash_attn_varlen_func + if not _use_flash_attn_3 + else flash_attn_varlen_func_v3 + ) + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append(max_seqlen_q) + fa_optional_forward_args_thd.append(max_seqlen_kv) + if _use_flash_attn_3: + fa_3_optional_forward_kwargs = {} + fa_3_optional_forward_kwargs["window_size"] = window_size + fa_3_optional_forward_kwargs["deterministic"] = self.deterministic + if fp8: + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] + torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + torch_orig_dtype = query_layer.dtype + + def convert_to_torch_float8(tensor, dtype): + out = torch.Tensor().to(device=tensor.device, dtype=dtype) + out.set_( + tensor._data.untyped_storage(), + tensor._data.storage_offset(), + tensor._data.shape, + tensor._data.stride(), + ) + return out + + # "fp8_mha" decides outputs in fp8, while inputs are inferred from + # the real dtype + assert isinstance(key_layer, query_layer.__class__) and isinstance( + value_layer, query_layer.__class__ + ), "q, k, and v must have the same type." 
+ if not isinstance(query_layer, Float8Tensor): + query_layer, key_layer, value_layer = ( + QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] + ) + fa_3_optional_forward_kwargs["descale_q"] = ( + query_layer._scale_inv.unsqueeze(0) + ) + fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( + 0 + ) + fa_3_optional_forward_kwargs["descale_v"] = ( + value_layer._scale_inv.unsqueeze(0) ) - return out - - # "fp8_mha" decides outputs in fp8, while inputs are inferred from - # the real dtype - assert isinstance(key_layer, query_layer.__class__) and isinstance( - value_layer, query_layer.__class__ - ), "q, k, and v must have the same type." - if not isinstance(query_layer, Float8Tensor): query_layer, key_layer, value_layer = ( - QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] + convert_to_torch_float8(x, torch_dtype) + for x in [query_layer, key_layer, value_layer] ) - fa_3_optional_forward_kwargs["descale_q"] = ( - query_layer._scale_inv.unsqueeze(0) - ) - fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( - 0 - ) - fa_3_optional_forward_kwargs["descale_v"] = ( - value_layer._scale_inv.unsqueeze(0) - ) - query_layer, key_layer, value_layer = ( - convert_to_torch_float8(x, torch_dtype) - for x in [query_layer, key_layer, value_layer] - ) - try: - output, _ = func( + try: + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_3_optional_forward_kwargs, + ) + except TypeError as e: + if _flash_attn_3_0_0_beta: + e.args = ( + e.args[0] + + ". Please update your flash-attn v3 (beta) installation as it " + + "may have added more supported arguments to its API. 
\n" + + _flash_attn_3_installation_steps, + ) + e.args[1:] + raise + + if fp8: + output = output.to(dtype=torch_orig_dtype) + if fp8 and fp8_meta["recipe"].fp8_mha: + O_quantizer = quantizers["scaling_fwd"][META_O] + output = O_quantizer(output) + else: + output = func( query_layer, key_layer, value_layer, *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, - **fa_3_optional_forward_kwargs, + **fa_optional_forward_kwargs, ) - except TypeError as e: - if _flash_attn_3_0_0_beta: - e.args = ( - e.args[0] - + ". Please update your flash-attn v3 (beta) installation as it " - + "may have added more supported arguments to its API. \n" - + _flash_attn_3_installation_steps, - ) + e.args[1:] - raise - - if fp8: - output = output.to(dtype=torch_orig_dtype) - if fp8 and fp8_meta["recipe"].fp8_mha: - O_quantizer = quantizers["scaling_fwd"][META_O] - output = O_quantizer(output) - else: - output = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) - if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type and inference_params is None: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format == "sbhd": @@ -7751,6 +7763,7 @@ def forward( ) if inference_params is not None: + inference_params.is_output_right_aligned = use_flash_attention output = inference_params.post_step(self.layer_number, output) return output diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9dc35e0d5a..25e11070fc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -43,7 +43,7 @@ void reshape_q( void 
reshape_o( torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len); + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned); void copy_to_kv_cache( torch::Tensor new_k, torch::Tensor new_v, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 5a518adf71..5ae2f19c5c 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -111,12 +111,15 @@ __global__ void reshape_o_kernel( scalar_t* output_buffer, int* step_lens, int h_o, int d_o, - int b, int max_seq_len) { + int b, int max_seq_len, bool is_output_right_aligned) { // output: bshd; output_buffer: thd; // step_lens: [b + 1] for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int num_elts = step_lens[batch_idx] * h_o * d_o; int output_offset = batch_idx * max_seq_len * h_o * d_o; + if (is_output_right_aligned) { + output_offset = ((batch_idx + 1) * max_seq_len - step_lens[batch_idx]) * h_o * d_o; + } int output_buffer_offset = 0; for (int t = 0; t < batch_idx; t ++) { output_buffer_offset += step_lens[t]; @@ -134,36 +137,36 @@ template void reshape_o_launcher( torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len) { + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(output.data_ptr()), reinterpret_cast(output_buffer.data_ptr()), step_lens.data_ptr(), - h_o, d_o, b, max_seq_len); + h_o, d_o, b, max_seq_len, is_output_right_aligned); } void reshape_o( torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len) { + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { NVTE_CHECK( output.scalar_type() == 
output_buffer.scalar_type(), "output and output_buffer must be of the same data type."); if (output.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); } else if (output.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); } else if (output.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); // } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { // using dtype = at::kFloat8_e4m3fn; -// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); +// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); // } else if (output.scalar_type() == at::ScalarType::Float8_e5m2) { // using dtype = at::kFloat8_e5m2; -// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len); +// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); } diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index f22313c972..5610f4f405 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -74,6 +74,14 @@ def __init__( # page table, [batch_size, max_pages_per_seq] self.page_table = None + def reset(self): + self.sequences = OrderedDict() + 
self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + self.allocated_pages = defaultdict(list) + self.page_table.fill_(0) + def allocate_memory(self, layer_number): """Allocate memory for the KV cache""" k_cache = torch.empty( From 56c0c0701fe5d010f5368f47be5d1366a9c499a0 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 19 Feb 2025 02:40:07 -0800 Subject: [PATCH 096/239] [PyTorch] Fix typo (#1495) Fix typo Signed-off-by: Tim Moon Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index aa5964bc4a..fe023208d1 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -874,7 +874,7 @@ def _all_gather_fp8( dtype = input_.dtype device = input_.device out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - elif isinstance(input, Float8Tensor): + elif isinstance(input_, Float8Tensor): out = input_.make_like(input_, shape=out_shape) out._data = torch.empty_like( out_shape, From fceff07a59bacd517baaf6a0f9cb0fb087f117ea Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 20 Feb 2025 05:55:41 +0800 Subject: [PATCH 097/239] [PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear (#1488) * fix fuse_wgrad_accumulation for GroupedLinear Signed-off-by: Xin Yao * fix fuse_wgrad_accumulation for GroupedLinear Signed-off-by: Xin Yao * update tests Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 33 +++++++++++-- .../pytorch/module/grouped_linear.py | 49 ++++++++++--------- 2 files changed, 53 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 22735c5292..a72ba097a1 100644 --- 
a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1400,7 +1400,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): +def _test_grouped_linear_accuracy( + block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation +): reset_rng_states() if fp8: FP8GlobalStateManager.reset() @@ -1447,7 +1449,11 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f outputs = [out, inp_hidden_states.grad] for p in block.parameters(): if p.requires_grad: - outputs.append(p.grad) + if getattr(p, "main_grad", None) is not None: + outputs.append(p.main_grad) + assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True + else: + outputs.append(p.grad) return outputs @@ -1458,8 +1464,17 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None + dtype, + num_gemms, + bs, + model, + fp8, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + parallel_mode=None, ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -1481,6 +1496,7 @@ def test_grouped_linear_accuracy( params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() sequential_linear = torch.nn.ModuleList( [ @@ -1491,6 +1507,7 @@ def test_grouped_linear_accuracy( params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() for _ in range(num_gemms) ] @@ -1501,12 
+1518,16 @@ def test_grouped_linear_accuracy( for i in range(num_gemms): sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, recipe, fp8 + sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation ) outputs = _test_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation ) # Shoule be bit-wise match @@ -1527,6 +1548,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): recipe=recipe, fp8_model_params=True, parallel_mode=parallel_mode, + fuse_wgrad_accumulation=True, ) @@ -1541,6 +1563,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe): fp8=True, recipe=recipe, fp8_model_params=True, + fuse_wgrad_accumulation=True, ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index cab8dff7c2..10b21f25c6 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -178,7 +178,6 @@ def forward( if is_grad_enabled: - saved_inputs, saved_weights = [], [] ctx.weights_shape_1 = weights[0].shape[1] tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) @@ -186,9 +185,11 @@ def forward( ctx.tensor_objects = tensor_objects ctx.weights_requires_grad = weights[0].requires_grad + if fuse_wgrad_accumulation and ctx.weights_requires_grad: + ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)] + else: + ctx.main_grads = 
[None] * num_gemms ctx.device = device - ctx.saved_inputs = saved_inputs - ctx.saved_weights = saved_weights ctx.grad_output_quantizers = grad_output_quantizers ctx.m_splits = m_splits ctx.num_gemms = num_gemms @@ -220,7 +221,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = saved_tensors[:N] weights = saved_tensors[N : 2 * N] biases = saved_tensors[2 * N : 3 * N] - main_grads = saved_tensors[3 * N :] + main_grads = ctx.main_grads if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO for i in ctx.num_gemms: @@ -281,31 +282,31 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation: - wgrad_list = [w.main_grad for w in weights] + wgrad_list = main_grads else: wgrad_list = [ torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - # WGRAD - _, grad_biases_, _ = general_grouped_gemm( - inputmats, - grad_output, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - m_splits=ctx.m_splits, - use_bias=ctx.use_bias if grad_biases[0] is None else None, - bias=biases, - use_split_accumulator=_2X_ACC_WGRAD, - accumulate=accumulate_wgrad_into_param_main_grad, - ) - for i in range(ctx.num_gemms): - if grad_biases[i] is None: - grad_biases[i] = grad_biases_[i] - del grad_biases_ + # WGRAD + _, grad_biases_, _ = general_grouped_gemm( + inputmats, + grad_output, + wgrad_list, + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NT", + grad=True, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ) + for i in range(ctx.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ # Deallocate input tensor clear_tensor_data(*inputmats) From 
33b430f9fb12eef75a8f67128812b9663222b28e Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Wed, 19 Feb 2025 15:19:17 -0800 Subject: [PATCH 098/239] commit two files missed by bcef6b34 Signed-off-by: Charlene Yang --- transformer_engine/pytorch/inference.py | 393 ++++++++++++++++++ .../pytorch/kv_cache_manager.py | 45 ++ 2 files changed, 438 insertions(+) create mode 100644 transformer_engine/pytorch/inference.py create mode 100644 transformer_engine/pytorch/kv_cache_manager.py diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py new file mode 100644 index 0000000000..a1d5526784 --- /dev/null +++ b/transformer_engine/pytorch/inference.py @@ -0,0 +1,393 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Inference.""" +import collections +from typing import Dict, List +from einops import rearrange + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat +from transformer_engine.pytorch.kv_cache_manager import KVCacheManager +from transformer_engine.pytorch.kv_cache_manager_paged import PagedKVCacheManager +from transformer_engine.pytorch.kv_cache_manager_non_paged import NonPagedKVCacheManager + +class InferenceParams: # pylint: disable=too-few-public-methods + """ + Inference parameters that are passed to the main model in order + to efficiently calculate and store the context and previously generated tokens + during inference. + + Parameters + ---------- + max_batch_size : int + maximum batch size during inference. + max_sequence_length : int + maximum sequence length during inference. + num_heads: int + number of attention heads in key/value tensor. + head_dim_k: int + head size for the key tensor. + dtype: torch.dtype + data type for the KV cache. + head_dim_v: Optional[int], default = None + head size for the value tensor. If None, it will be set to head_dim_k. 
+ is_paged: bool, default = False + whether the KV cache is paged or non-paged (contiguous). + total_num_pages: Optional[int], default = None + total number of pages in the K cache or V cache if is_paged = True. + page_size: Optional[int], default = None + page size in number of tokens if is_paged = True. + """ + + def __init__( + self, + max_batch_size: int, + max_seqlen_kv: int, + num_heads_kv: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: int = None, + is_paged: bool = False, + total_num_pages: int = None, + page_size: int = None, + num_heads_q: int = None, + head_dim_q: int = None, + max_ctx_len: int = None, + qkv_format: str = "bshd", + cache_manager: KVCacheManager = None, + ): + self.max_batch_size = max_batch_size + self.max_seqlen_kv = max_seqlen_kv + self.num_heads_kv = num_heads_kv + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + self.is_paged = is_paged + + if not self.is_paged: + cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager + self.cache_manager = cls( + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + head_dim_v=self.head_dim_v, + ) + else: + assert page_size is not None, "Paged KV cache requires page_size is not None." + assert ( + max_seqlen_kv % page_size == 0 + ), "Paged KV cache requires max_seqlen_kv % page_size = 0." + max_pages_per_seq = max_seqlen_kv // page_size + assert ( + total_num_pages == self.max_batch_size * max_pages_per_seq + ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." 
+ self.page_size = page_size + self.max_seqlen_kv = max_seqlen_kv + self.total_num_pages = total_num_pages + + cls = cache_manager if cache_manager is not None else PagedKVCacheManager + self.cache_manager = cls( + total_num_pages=self.total_num_pages, + page_size=self.page_size, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + head_dim_v=self.head_dim_v, + ) + + if qkv_format == "thd": + # query will be converted to 'bshd' to be consistent with cache format + assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" + assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" + assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" + self.num_heads_q = num_heads_q + self.head_dim_q = head_dim_q + self.max_ctx_len = max_ctx_len + self.max_seqlen_q = max_ctx_len + + # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache + self.cache_qkv_format = "bshd" + self.input_qkv_format = qkv_format + + self.sequences_prev = collections.OrderedDict() + self.sequences = collections.OrderedDict() + self.step_dict = collections.OrderedDict() + self.batch_size = 0 + + self.cu_seqlens_q = None + self.cu_seqlens_kv = None + + # original q will be used as the output buffer + self.q_orig = {} + # convert q to 'bshd' to be consistent with cache format + self.q_buffer = {} + + self.is_output_right_aligned = False + + def reset(self): + """ + Reset the state of InferenceParams. 
+ """ + self.sequences = collections.OrderedDict() + self.cache_manager.reset() + if self.input_qkv_format == 'thd': + for layer_number in self.q_buffer: + self.q_buffer[layer_number].fill_(0) + + def __repr__(self) -> str: + if self.is_paged: + return ( + f"dtype={self.dtype}, " + f"is_paged={self.is_paged}, " + f"total_pages={self.total_num_pages}, " + f"page_size={self.page_size}, " + f"num_heads={self.num_heads_kv}, " + f"head_dim_k={self.head_dim_k}, " + f"head_dim_v={self.head_dim_v}" + ) + return ( + f"dtype={self.dtype}, " + f"is_paged={self.is_paged}, " + f"max_batch_size={self.max_batch_size}, " + f"max_seqlen={self.max_seqlen_kv}, " + f"num_heads={self.num_heads_kv}, " + f"head_dim_k={self.head_dim_k}, " + f"head_dim_v={self.head_dim_v}" + ) + + def allocate_memory(self, layer_number: int, qkv_format: str): + """ + Allocate memory for the KV cache for the layer #layer_number. + Both K cache and V cache are in 'bshd' format. + - non-paged: + - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] + - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] + - paged: + - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] + - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] + If is_cuda_graph = True, several buffers are also allocated. 
+ - Q buffer: [max_batch_size, max_seqlen_kv, num_heads_q, head_dim_q] + - cu_seqlens_q buffer: [max_batch_size + 1] + - cu_seqlens_kv buffer: [max_batch_size + 1] + """ + self.cache_manager.allocate_memory(layer_number) + + if qkv_format == 'thd': + self.q_buffer[layer_number] = torch.zeros( + self.max_batch_size, + self.max_ctx_len, + self.num_heads_q, + self.head_dim_q, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + + self.cu_seqlens_q = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_kv = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def pre_step( + self, + step_dict: Dict[List, List], + ): + """ + Prepare for step(). + """ + self.step_dict = step_dict + self.batch_size = len(step_dict) + self.sequences_prev = self.sequences + self.sequences = self.cache_manager.pre_step(step_dict) + + actual_batch_size = len(step_dict) + seqlens_q = list(step_dict.values()) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) + self.cu_seqlens_q.copy_( + torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") + ) + seq_lens = list(self.sequences.values()) + cu_seqlens_kv = [0] + [sum(seq_lens[:i]) for i in range(1, actual_batch_size + 1)] + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * (self.max_batch_size - actual_batch_size) + self.cu_seqlens_kv.copy_( + torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") + ) + + def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): + """ + Convert the k cache and v cache from paged to non-paged format. This function + can be used for debugging purposes or for backends that do not have paged attention + support yet, for example, UnfusedDotProductAttention. + + It can be called after step(). 
Based on the page table, it re-indexes the cache + tensors and returns the contiguous, non-paged, key and value tensors. The kv cache tensors + are assumed to be in 'bshd' format (see self.allocate_memory), and the returned key and + value tensors will be in :attr:`qkv_format` to be consistent with the original inputs. + + Parameters + ---------- + layer_number: int + The layer number of the kv cache + qkv_format: str + The format of the returned key and value tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Non-paged key cache tensor + v_cache: torch.Tensor + Non-paged value cache tensor + """ + k_cache, v_cache = self.cache_manager.cache[layer_number] + page_table = self.cache_manager.page_table + batch_size = page_table.shape[0] + actual_batch_size = len(self.step_dict) + seqlens = list(self.sequences.values()) + new_k_cache = rearrange( + k_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, + ) + new_v_cache = rearrange( + v_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, + ) + if qkv_format == "thd": + new_k_cache = new_k_cache.contiguous() + new_v_cache = new_v_cache.contiguous() + else: + new_k_cache = new_k_cache[:actual_batch_size].contiguous() + new_v_cache = new_v_cache[:actual_batch_size].contiguous() + return new_k_cache, new_v_cache + + def step( + self, + layer_number: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str, + ): + """ + Update KV cache with the new key/value tokens for a given inference iteration. + + NonPagedKVCacheManager and PagedKVCacheManager are two examples of the cache manager. + Users can write their own cache manager with their own step() function. 
+ + If the inference iteration has only generation sequences, :attr:`k` and :attr:`v` tensors + should have shape: + - [batch_size, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', + - [1, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and + - [batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. + + If the inference iteration has both generation sequences and context sequences, :attr:`k` + and :attr:`v` should be arranged in a way so that the sequences in generation phase come + before the sequences in context phase, in the tensor. They should have the following shape. + - [batch_size, max_seqlen, num_heads, head_dim] for :attr:`qkv_format` = 'bshd' + - [max_seqlen, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and + - [total_num_new_tokens, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. + Here, max_seqlen is the maximum sequence length for the new tokens in the batch, and it may + be smaller than InferenceParams.max_seqlen_kv. + + Take a batch of 4, with seq_ids = [0, 1, 2, 3], as an example. At iteration t, all 4 sequences + are processed, after which, sequence 2 is determined to be 'finished'. For iteration t+1, there + may or may not be a new sequence added to the batch. + + If no new sequence is added, input tensors :attr:`k` and :attr:`v` should have shape + [3, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', [1, 3, num_heads, head_dim] for + :attr:`qkv_format` = 'sbhd', and [3, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. + + If one new sequence is added, for example, sequence 8 with 10 context tokens, then input tensors + :attr:`k` and :attr:`v` should be in [4, 10, num_heads, head_dim] shape if + :attr:`qkv_format` = 'bshd', [10, 4, num_heads, head_dim] if :attr:`qkv_format` = 'sbhd', + or [13, num_heads, head_dim] if :attr:`qkv_format` = 'thd'. 
+ + Parameters + ---------- + layer_number: int + The layer number of the kv cache + k: torch.Tensor + The new key tokens for the current iteration + v: torch.Tensor + The new value tokens for the current iteration + qkv_format: str + The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + The key cache tensor, containing tokens from both previous and current iterations + v_cache: torch.Tensor + The value cache tensor, containing tokens from both previous and current iterations + page_table: torch.Tensor + The page table if is_paged = True; else `None` + """ + self.input_qkv_format = qkv_format + + if qkv_format == "bshd": + q_buffer = q.contiguous() + self.max_seqlen_q = q_buffer.shape[1] + if qkv_format == "sbhd": + q_buffer = q.transpose(0, 1).contiguous() + self.max_seqlen_q = q_buffer.shape[1] + if qkv_format == "thd": + self.q_orig[layer_number] = q + self.max_seqlen_q = self.max_ctx_len + + q_buffer = self.q_buffer[layer_number] + step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + ctx_len = 1 + if qkv_format == "bshd": + ctx_len = q.shape[1] + if qkv_format == "sbhd": + ctx_len = q.shape[0] + tex.reshape_q( + q, q_buffer, step_lens, + QKVFormat[qkv_format], + self.num_heads_q, self.head_dim_q, + self.max_batch_size, ctx_len, self.max_ctx_len) + + k_cache, v_cache, page_table = self.cache_manager.step( + layer_number, k, v, self.cu_seqlens_q, self.cu_seqlens_kv, qkv_format, + ) + + return q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, self.max_seqlen_q, self.max_seqlen_kv, self.cache_qkv_format + + def post_step( + self, + layer_number: int, + output: torch.Tensor, + ): + """ + Process the attention output in order to return it in the original qkv_format. 
+ """ + if self.input_qkv_format == "bshd": + output = output[:self.batch_size, :self.max_seqlen_q].contiguous() + if self.input_qkv_format == "sbhd": + output = output[:self.batch_size, :self.max_seqlen_q].transpose(0, 1).contiguous() + if self.input_qkv_format == "thd": + #print('oooo ', output[:2, :, :4]) + output_buffer = self.q_orig[layer_number] + step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + tex.reshape_o(output, output_buffer, step_lens, + self.num_heads_q, self.head_dim_q, self.batch_size, self.max_ctx_len, self.is_output_right_aligned) + output = output_buffer.view(output_buffer.shape[0], -1) + + return output + + diff --git a/transformer_engine/pytorch/kv_cache_manager.py b/transformer_engine/pytorch/kv_cache_manager.py new file mode 100644 index 0000000000..072919821f --- /dev/null +++ b/transformer_engine/pytorch/kv_cache_manager.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""KV Cache Manager.""" +from collections import OrderedDict +from typing import Dict, List + +import torch + +class KVCacheManager: + """ + KV cache manager. The base class for custom cache managers. + """ + def __init__(self, *args, **kwargs): + """Initialize the cache manager.""" + self.cache = {} + self.sequences = OrderedDict() + + def reset(self): + """Empty tracked sequences""" + self.sequences = OrderedDict() + + def allocate_memory(self, layer_number: int): + """Allocate memory for the KV cache.""" + self.cache[layer_number] = (None, None) + + def pre_step( + self, + step_dict: Dict[List, List], + ): + """Prepare for operations in step(). 
Update sequences with step_dict.""" + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + qkv_format: str, + ): + """Update the cache with new_k and new_v tokens""" + return *self.cache[layer_number], None From b612cdebde53c26b4aaa38907472bc6588cfc211 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 20 Feb 2025 13:00:09 +0530 Subject: [PATCH 099/239] Fix TE ops API compatibility with PyTorch versions < 2.4.3 (#1494) * Fix te sequential for older pytorch versions Signed-off-by: Kirthi Shankar Sivamani * FIxes Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/distributed/test_torch_fsdp2.py | 16 +++------------- transformer_engine/pytorch/ops/_common.py | 10 ++++++++-- transformer_engine/pytorch/utils.py | 7 +++++++ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 4298d17c9c..bad09bf32a 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -7,19 +7,9 @@ import subprocess from pathlib import Path from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -import torch -from packaging.version import Version as PkgVersion - - -def get_torch_version(): - """Get PyTorch version from __version__""" +from transformer_engine.pytorch.utils import torch_version - def get_torch_version_str(): - import torch - - return str(torch.__version__) - - return PkgVersion(get_torch_version_str()) +import torch fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -44,7 +34,7 @@ def _run_test(fp_init, sharding_dims): @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") -@pytest.mark.skipif(not 
get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+") +@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("fp8_init", (False, True)) def test_distributed(fp8_init, sharding_dims): diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index bb826e552e..b4631eb9a7 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -16,6 +16,7 @@ canonicalize_device, canonicalize_dtype, devices_match, + torch_version, ) @@ -98,8 +99,13 @@ def maybe_autocast_dtype( default_dtype: Optional[torch.dtype] = None, ) -> torch.dtype: """Get autocast dtype if enabled""" - if torch.is_autocast_enabled(device_type): - return torch.get_autocast_dtype(device_type) + + if torch_version() >= (2, 4, 3): + if torch.is_autocast_enabled(device_type): + return torch.get_autocast_dtype(device_type) + else: + if torch.is_autocast_enabled(): + return torch.get_autocast_gpu_dtype() return canonicalize_dtype(default_dtype) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 1922a7e867..4678097dc4 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -8,6 +8,7 @@ import math import os from typing import Any, Callable, List, Optional, Tuple +from packaging.version import Version as PkgVersion import torch import transformer_engine.pytorch.cpp_extensions as ext @@ -386,3 +387,9 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None: # Pop NVTX range torch.cuda.nvtx.range_pop() + + +@functools.lru_cache(maxsize=None) +def torch_version() -> tuple[int, ...]: + """Get PyTorch version""" + return PkgVersion(str(torch.__version__)).release From 257345a56d006bb24be890bfd813b1d1299807a8 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> 
Date: Thu, 20 Feb 2025 10:13:15 -0800 Subject: [PATCH 100/239] [PyTorch] Fix CP implementation with FP8 (#1483) * commit some debug code Signed-off-by: Xiaowei Ren * add more debug info Signed-off-by: Xiaowei Ren * debug code commit and typo fix Signed-off-by: Xiaowei Ren * a typo fix Signed-off-by: Xiaowei Ren * remove debug info Signed-off-by: Xiaowei Ren * do not return lse Signed-off-by: Xiaowei Ren * add amax_per_step for quantizers of CP Signed-off-by: Xiaowei Ren * fix FP8 + CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bug fix Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * dtype fix Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 262 +++++++++++++++--------- transformer_engine/pytorch/fp8.py | 2 +- 2 files changed, 166 insertions(+), 98 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8584431dc2..d6b9894fc3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1894,11 +1894,12 @@ def forward( fused_attn_backend = None qkv_dtype = q.dtype + amax_per_step = None + S_quantizer_per_step = [None for _ in range(cp_size)] + O_CP_quantizer_per_step = [None for _ in range(cp_size)] # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = False - if fp8: - is_output_fp8 = fp8_meta["recipe"].fp8_mha ( QKV_quantizer, @@ -1919,28 +1920,30 @@ def forward( v, q.__class__ ), "q, k, and v must have the same type." 
is_input_fp8 = isinstance(q, Float8Tensor) - if not is_input_fp8: + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + if is_input_fp8: + QKV_quantizer = q._quantizer + q, k, v = q._data, k._data, v._data + else: q_f16, k_f16, v_f16 = q, k, v if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16) + q = QKV_quantizer(q_f16)._data if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]] - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer # partial result quantizer + k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # partial result quantizer + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i] + O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() + O_CP_quantizer_per_step[i].amax = amax_per_step[1][i] else: assert False, "FP8 is only supported with Fused Attention!" 
else: q_f16 = q if use_fused_attention: - fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if fp8: - q = q._data - k = k._data - v = v._data - if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) @@ -2067,7 +2070,7 @@ def forward( kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i]) + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data if causal: if i == 0: if pad_between_seqs_q: @@ -2120,6 +2123,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2130,6 +2134,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, @@ -2243,6 +2249,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2253,6 +2260,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -2385,6 +2394,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2395,6 +2405,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, 
internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q // 2, @@ -2507,6 +2519,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2517,6 +2530,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -2595,7 +2610,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: - out_per_step[i - 1] = out_per_step[i - 1].dequantize() + out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) if i == 1: out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) @@ -2697,6 +2712,11 @@ def forward( elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + S_quantizer.amax = amax_cp_fwd[0] + O_CP_quantizer.amax = amax_cp_fwd[1] + out_fp8 = None out_f16 = out.to(qkv_dtype) @@ -2708,7 +2728,7 @@ def forward( if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, kv_save, out_save = q, kv, out_fp8._data elif fp8 and is_input_fp8: - q_save, kv_save, out_save = q, k, out_f16 + q_save, kv_save, out_save = q, kv, out_f16 else: q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 @@ -2737,7 +2757,6 @@ def forward( ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer - ctx.qkv_dtype = qkv_dtype ctx.cp_group_a2a = cp_group_a2a ctx.cp_size_a2a 
= cp_size_a2a @@ -2778,10 +2797,8 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - saved_tensors = ctx.saved_tensors - q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( - restore_from_saved(ctx.tensor_objects, saved_tensors) + restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) ) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] @@ -2843,39 +2860,59 @@ def backward(ctx, dout): dout_dtype = dout.dtype fused_attn_backend = None fused_attn_dqkv_dtype = None + amax_per_step = None + dP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) - dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + dqkv_fp8_torch_dtype = get_fp8_torch_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False + ) + dq_fp8 = torch.empty( + (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device + ) + dkv_fp8 = torch.empty( + (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device + ) dkv_fp8_ = torch.empty_like(dkv_fp8) if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
- fused_attn_dqkv_dtype = dout._fp8_dtype - dout = dout._data + ctx.dO_quantizer = dout._quantizer else: dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = dout._fp8_dtype - dout = dout._data + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer - fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i] + dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() + dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i] else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - kv = ctx.QKV_quantizer.create_tensor_from_data( - kv, fake_dtype=ctx.qkv_dtype, internal=True - ) - q, kv = q.dequantize(), kv.dequantize() - if cp_size_a2a == 1: - dout = dout.dequantize() + if ctx.fp8_meta is not None: + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q = q.dequantize(dtype=ctx.qkv_dtype) + kv = kv.dequantize(dtype=ctx.qkv_dtype) + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
+ if cp_size_a2a == 1: + dout = dout.dequantize(dtype=dout_dtype) + else: + ctx.dO_quantizer = dout._quantizer + dout = dout._data dq = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2902,9 +2939,10 @@ def backward(ctx, dout): True, ) if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True) - dout = dout.dequantize() - dout = dout._data + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + dout = dout.dequantize(dtype=dout_dtype) out = out.view(*q.shape) dout = dout.view(*q.shape) @@ -3020,8 +3058,10 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3133,8 +3173,10 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, @@ -3250,8 +3292,10 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = 
dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, @@ -3282,7 +3326,6 @@ def backward(ctx, dout): dq_ = dq_._data dk_ = dk_._data dv_ = dv_._data - else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -3333,20 +3376,22 @@ def backward(ctx, dout): if ctx.fp8: q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype + q_part, fake_dtype=ctx.qkv_dtype, internal=True ) k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype + k_part, fake_dtype=ctx.qkv_dtype, internal=True ) v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype + v_part, fake_dtype=ctx.qkv_dtype, internal=True ) out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype + out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3555,13 +3600,20 @@ def backward(ctx, dout): dkv.add_(dkv_) if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax = amax_cp_bwd[0] + ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1] if ctx.qkv_format in ["bshd", "sbhd"]: # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8) - dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8) - dq, dkv = [x.dequantize() for x in [dq, dkv]] + dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dq_fp8, fake_dtype=torch.float32, internal=True + ) + 
dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: @@ -3606,9 +3658,9 @@ def backward(ctx, dout): attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) # converting torch.uint8 to float8tensor if ctx.fp8 and ctx.is_input_fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") return ( @@ -4227,21 +4279,20 @@ def forward( # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = False - if fp8: - is_output_fp8 = fp8_meta["recipe"].fp8_mha QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) ) if fp8: if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." 
is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: + QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): @@ -4350,31 +4401,24 @@ def forward( out = out_fp8._data else: out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=False + out, fake_dtype=qkv_dtype, internal=True ) - out_f16 = out_fp8.dequantize() + out_f16 = out_fp8.dequantize(dtype=qkv_dtype) out_ret = out_f16 else: out_ret = out - if fp8: - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - elif is_input_fp8: - q_fp8 = QKV_quantizer.create_tensor_from_data( - q, fake_dtype=qkv_dtype, internal=False - ) - k_fp8 = QKV_quantizer.create_tensor_from_data( - k, fake_dtype=qkv_dtype, internal=False - ) - v_fp8 = QKV_quantizer.create_tensor_from_data( - v, fake_dtype=qkv_dtype, internal=False - ) - q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out - else: - q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 - else: + if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out + else: + if is_input_fp8: + q_save, k_save, v_save = q, k, v + else: + q_save, k_save, v_save = q_f16, k_f16, v_f16 + if is_output_fp8: + out_save = out + else: + out_save = out_f16 tensors_to_save, tensor_objects = prepare_for_saving( q_save, @@ -4397,7 +4441,6 @@ def forward( ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer - ctx.qkv_dtype = qkv_dtype ctx.batch_size = batch_size ctx.cp_group = cp_group @@ -4436,27 +4479,24 @@ def backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - dout_dtype = dout.dtype qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in 
ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") + dout_dtype = dout.dtype fused_attn_backend = None fused_attn_dqkv_dtype = None if ctx.fp8: - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_dqkv_dtype = fp8_dtype_backward - if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - dout_fp8 = dout - dout = dout_fp8._data + ctx.dO_quantizer = dout._quantizer else: - dout_f16 = dout - dout = ctx.dO_quantizer(dout_f16)._data + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -4465,12 +4505,25 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]] + if ctx.fp8_meta is not None: + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
+ ctx.dO_quantizer = dout._quantizer + dout = dout._data + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + k = ctx.QKV_quantizer.create_tensor_from_data( + k, fake_dtype=ctx.qkv_dtype, internal=True + ) + v = ctx.QKV_quantizer.create_tensor_from_data( + v, fake_dtype=ctx.qkv_dtype, internal=True + ) + q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: @@ -4481,6 +4534,15 @@ def backward(ctx, dout): out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: + out = ctx.O_quantizer.create_tensor_from_data( + out, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + out = out.dequantize(dtype=ctx.qkv_dtype) + dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -4531,7 +4593,7 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) dq, dk, dv, _ = fused_attn_bwd( @@ -4602,11 +4664,17 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + dq = ctx.dQKV_quantizer.create_tensor_from_data( + dq, 
fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dk = ctx.dQKV_quantizer.create_tensor_from_data( + dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dv = ctx.dQKV_quantizer.create_tensor_from_data( + dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) if not ctx.is_input_fp8: - dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] + dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 254bcf12e1..f788368112 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -56,7 +56,7 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch. fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): return torch.float8_e4m3fn - return torch.float8_e5m2fn + return torch.float8_e5m2 def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: From 1c31b68d02a0d0f527f495f4746d8e389c06165a Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 21 Feb 2025 15:28:21 -0800 Subject: [PATCH 101/239] WIP: thd_bshd_bshd Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 13 +- .../common/fused_attn/fused_attn.cpp | 174 ++++++++++--- .../fused_attn_f16_arbitrary_seqlen.cu | 232 +++++++++++------- .../fused_attn_f16_arbitrary_seqlen.h | 5 +- transformer_engine/common/fused_attn/utils.cu | 66 +++-- .../include/transformer_engine/fused_attn.h | 54 ++-- .../common/util/pybind_helper.h | 20 +- transformer_engine/pytorch/attention.py | 81 ++++-- transformer_engine/pytorch/constants.py | 14 +- .../pytorch/cpp_extensions/fused_attn.py | 44 +++- transformer_engine/pytorch/inference.py | 49 ++-- 11 files changed, 521 insertions(+), 231 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 35fa8984dc..60ee5e87ca 
100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -196,9 +196,9 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("dtype", [torch.float16])#param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() @@ -211,7 +211,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): # figure out supported backends inference_params_qkv_format = "bshd" if is_paged: - qkv_layout = "paged_kv_" + inference_params_qkv_format + "_2" + inference_params_qkv_format + qkv_layout = "paged_kv_" + "_".join([inference_params_qkv_format] * 3) else: qkv_layout = "_".join([inference_params_qkv_format] * 3) available_backends, fused_attn_backends = _get_attention_backends( @@ -356,7 +356,7 @@ def gen_data(): dtype=torch.int32, ) sample_kwargs["inference_params"] = inference_params - sample_kwargs["attn_mask_type"] = "padding_causal" + sample_kwargs["attn_mask_type"] = "padding" #_causal" sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv sample_kwargs["qkv_format"] = qkv_format @@ -485,7 +485,7 @@ def gen_data(): cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, - attn_mask_type="padding_causal", + attn_mask_type="padding", #_causal", max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, @@ -525,6 +525,9 @@ def gen_data(): rtol=tols[dtype], 
) if qkv_format == "thd": + print('i ', i, seq, cu_seqlens_q) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], #line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6df4aee315..f7fc9149bd 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -37,15 +37,16 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD: - return NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD; - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD: - return NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD; + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -59,24 +60,68 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: case 
NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_SBHD; case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_BSH3D: case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Layout::NVTE_T3HD: case NVTE_QKV_Layout::NVTE_TH3D: case NVTE_QKV_Layout::NVTE_THD_T2HD: case NVTE_QKV_Layout::NVTE_THD_TH2D: case NVTE_QKV_Layout::NVTE_THD_THD_THD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD: + return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_SBHD_2BSHD; + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_BSHD_2SBHD; + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_THD_2BSHD; + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_THD_2SBHD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// map NVTE_QKV_Layout to NVTE_QKV_Format for Q +NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + switch (qkv_format) { + case NVTE_QKV_Format::NVTE_SBHD: + case NVTE_QKV_Format::NVTE_SBHD_2BSHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_BSHD_2SBHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Format::NVTE_THD: + case NVTE_QKV_Format::NVTE_THD_2BSHD: + case NVTE_QKV_Format::NVTE_THD_2SBHD: + 
return NVTE_QKV_Format::NVTE_THD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// map NVTE_QKV_Layout to NVTE_QKV_Format for KV +NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + switch (qkv_format) { + case NVTE_QKV_Format::NVTE_SBHD: + case NVTE_QKV_Format::NVTE_BSHD_2SBHD: + case NVTE_QKV_Format::NVTE_THD_2SBHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_SBHD_2BSHD: + case NVTE_QKV_Format::NVTE_THD_2BSHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; default: NVTE_ERROR("qkv_layout not supported!"); @@ -95,6 +140,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const int sm_arch_ = cuda::sm_arch(device_id); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); auto cudnn_runtime_version = cudnnGetVersion(); @@ -218,12 +265,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} (cudnn_runtime_version >= 90500 && - (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD) && + layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + //max_seqlen_q % 
64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) (cudnn_runtime_version >= 90600 && @@ -239,7 +285,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - cudnn_runtime_version >= 90600))) && + cudnn_runtime_version >= 90600)) || + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && + cudnn_runtime_version >= 90700)) && // sliding window // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && @@ -275,6 +324,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (supported_ragged_offset_size)) { flag_arb = true; } + flag_arb = true; if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } @@ -492,7 +542,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type 
attn_mask_type, @@ -504,6 +554,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = reinterpret_cast(page_table_k); + const Tensor *input_page_table_v = reinterpret_cast(page_table_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_KV = reinterpret_cast(KV); @@ -528,11 +580,40 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_KV->data.shape[0]; } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_KV->data.shape[0]; + page_size_k = input_KV->data.shape[1]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_KV->data.shape[1]; 
+ page_size_k = input_KV->data.shape[0]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -554,10 +635,11 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( @@ -619,9 +701,12 @@ void nvte_fused_attn_bwd_kvpacked( } size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_KV->data.shape[0]; } @@ -720,9 +805,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = 
nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } int64_t num_pages_k = 0; @@ -738,16 +826,19 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso max_pages_per_seq_v = input_page_table_v->data.shape[1]; } NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD) { - num_pages_k = input_K->data.shape[0]; - page_size_k = input_K->data.shape[1]; - num_pages_v = input_V->data.shape[0]; - page_size_v = input_V->data.shape[1]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD) { - num_pages_k = input_K->data.shape[1]; - page_size_k = input_K->data.shape[0]; - num_pages_v = input_V->data.shape[1]; - page_size_v = input_V->data.shape[0]; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_K->data.shape[0]; + page_size_k = input_K->data.shape[1]; + num_pages_v = input_V->data.shape[0]; + page_size_v = input_V->data.shape[1]; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_K->data.shape[1]; + page_size_k = input_K->data.shape[0]; + num_pages_v = input_V->data.shape[1]; + page_size_v = input_V->data.shape[0]; + } } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); @@ -833,9 +924,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == 
NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 671efe396e..313ca416da 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -76,25 +76,27 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_bottom_right = false; } bool is_dropout = (is_training && dropout_probability != 0.0f); - bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + NVTE_QKV_Format q_format = nvte_get_q_format(layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); - bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); } // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; - if (is_ragged && cudnn_runtime_version >= 90600) { + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so 
the graph is static within each quantization bucket b = max_b; - s_q = max_t_q; - s_kv = max_t_kv; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; @@ -194,7 +196,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_name("Q") .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); - if (is_ragged) { + if (is_ragged_q) { offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_q") .set_dim({b + 1, 1, 1, 1}) @@ -207,7 +209,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_paged_kv) { K->set_dim({num_pages_k, hg, page_size_k, d_qk}); V->set_dim({num_pages_v, hg, page_size_v, d_v}); - } else if (is_ragged) { + } else if (is_ragged_kv) { offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_k") .set_dim({b + 1, 1, 1, 1}) @@ -306,7 +308,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); - if (is_ragged) { + if (is_ragged_q) { offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_o") .set_dim({b + 1, 1, 1, 1}) @@ -316,7 +318,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged && cudnn_runtime_version >= 90600) { + if (is_ragged_q && cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") @@ -340,9 +342,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) : std::make_tuple(nullptr, nullptr); - auto offset_qkvo_tuple = is_ragged ? 
std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); - auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + auto offset_qo_tuple = is_ragged_q ? std::make_tuple(offset_q, offset_o) + : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) + : std::make_tuple(nullptr, nullptr); + auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -356,14 +360,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, page_table_tuple, - offset_qkvo_tuple, offset_s_tuple, dropout_tuple); + offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, - offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = + offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -375,11 +379,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); size_t seqlen_offsets_workspace_size = 0; - if (is_ragged) { - if (cudnn_runtime_version >= 90600) { - seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + if (is_ragged_q || is_ragged_kv) { + size_t count = 2 * ((size_t)is_ragged_q + (size_t)is_ragged_kv); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { - 
seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; } } if (workspace == nullptr) { @@ -420,16 +425,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[page_table_v] = devPtrPageTableV; } - if (is_ragged) { + if (is_ragged_q || is_ragged_kv) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; - void *devOffsetsQ = + void *devOffsetsQ = nullptr; + void *devOffsetsK = nullptr; + void *devOffsetsV = nullptr; + void *devOffsetsO = nullptr; + if (is_ragged_q) { + devOffsetsQ = static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; - void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; - void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + } + if (is_ragged_kv) { + devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + } void *devOffsetsS = nullptr; - if (cudnn_runtime_version >= 90600) { + if (is_ragged_q && cudnn_runtime_version >= 90600) { devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); @@ -437,11 +450,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); - variant_pack[offset_q] = devOffsetsQ; - variant_pack[offset_k] = devOffsetsK; - variant_pack[offset_v] = devOffsetsV; - variant_pack[offset_o] = devOffsetsO; - if (cudnn_runtime_version >= 90600) { + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + 
variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (is_ragged_q && cudnn_runtime_version >= 90600) { variant_pack[offset_stats] = devOffsetsS; } } @@ -484,27 +501,29 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_bottom_right = false; } bool is_dropout = (dropout_probability != 0.0f); - bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + NVTE_QKV_Format q_format = nvte_get_q_format(layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); - bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD || - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); } // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; - if (is_ragged && cudnn_runtime_version >= 90600) { + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so the graph is static within each quantization bucket b = max_b; - s_q = max_t_q; - s_kv = max_t_kv; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; } // We choose between 32-bit and 64-bit offsets depending on need. 
@@ -621,12 +640,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_name("dO") .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); - if (is_ragged) { + if (is_ragged_q) { offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_q") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + q->set_ragged_offset(offset_q); + o->set_ragged_offset(offset_o); + dO->set_ragged_offset(offset_o); + } + if (is_ragged_q) { offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_k") .set_dim({b + 1, 1, 1, 1}) @@ -637,23 +666,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - q->set_ragged_offset(offset_q); k->set_ragged_offset(offset_k); v->set_ragged_offset(offset_v); - o->set_ragged_offset(offset_o); - dO->set_ragged_offset(offset_o); } stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("stats") .set_dim({b, h, s_q, 1}) .set_data_type(fe::DataType_t::FLOAT)); - if (is_ragged && cudnn_runtime_version >= 90600) { + if (is_ragged_q && cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") @@ -679,8 +700,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); - if (is_ragged && cudnn_runtime_version >= 90600) { + if (is_ragged_q && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); + } + if 
(is_ragged_kv && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } @@ -747,8 +770,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); - if (is_ragged) { + if (is_ragged_q) { dQ->set_ragged_offset(offset_q); + } + if (is_ragged_kv) { dK->set_ragged_offset(offset_k); dV->set_ragged_offset(offset_v); } @@ -767,9 +792,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); - auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + auto offset_qo_tuple = is_ragged_q ? std::make_tuple(offset_q, offset_o) + : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) + : std::make_tuple(nullptr, nullptr); + auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) @@ -783,14 +810,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, - offset_qkvo_tuple, offset_s_tuple, dropout_tuple); + offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = + offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -802,11 +829,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); size_t seqlen_offsets_workspace_size = 0; - if (is_ragged) { - if (cudnn_runtime_version >= 90600) { - seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + if (is_ragged_q || is_ragged_kv) { + size_t count = 2 * ((size_t)is_ragged_q + (size_t)is_ragged_kv); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { - seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; } } if (workspace == nullptr) { @@ -855,16 +883,24 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } - if (is_ragged) { + if (is_ragged_q || is_ragged_kv) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; - void *devOffsetsQ = + void *devOffsetsQ = nullptr; + void *devOffsetsK = nullptr; + void *devOffsetsV = nullptr; + void *devOffsetsO = nullptr; + if (is_ragged_q) { + devOffsetsQ = 
static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; - void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; - void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + } + if (is_ragged_kv) { + devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + } void *devOffsetsS = nullptr; - if (cudnn_runtime_version >= 90600) { + if (is_ragged_q && cudnn_runtime_version >= 90600) { devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); @@ -872,11 +908,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); - variant_pack[offset_q] = devOffsetsQ; - variant_pack[offset_k] = devOffsetsK; - variant_pack[offset_v] = devOffsetsV; - variant_pack[offset_o] = devOffsetsO; - if (cudnn_runtime_version >= 90600) { + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (is_ragged_q && cudnn_runtime_version >= 90600) { variant_pack[offset_stats] = devOffsetsS; } } @@ -1106,12 +1146,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + size_t num_pages_k, size_t num_pages_v, size_t page_size_k, + size_t page_size_v, size_t max_pages_per_seq_k, 
size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1119,7 +1162,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrQ = input_Q->data.dptr; void *devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; @@ -1145,13 +1189,19 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + void *devPtrPageTableK = page_table_k->data.dptr; + void *devPtrPageTableV = page_table_v->data.dptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == 
NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } @@ -1161,7 +1211,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1179,7 +1229,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1214,10 +1264,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, + max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, + max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, nullptr, nullptr, devPtrSeqOffsetsQ, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, 
handle); @@ -1278,10 +1329,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } @@ -1343,7 +1399,8 @@ void fused_attn_arbitrary_seqlen_fwd( using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); void *devPtrQ = input_Q->data.dptr; void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; @@ -1368,9 +1425,13 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } @@ -1380,7 +1441,7 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && 
cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1398,7 +1459,7 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1488,10 +1549,15 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 1925a07443..c6cc212211 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -38,12 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + size_t num_pages_k, 
size_t num_pages_v, size_t page_size_k, + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index c395deea9b..daf4ce71c1 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -117,7 +117,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 } break; case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) || @@ -224,8 +224,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 break; case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == 
NVTE_QKV_Matrix::NVTE_O_Matrix)) { strideA[batch_dim_idx] = s_q * h * d; @@ -246,7 +247,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_transpose_dim_idx] = 1; } break; - case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { strideA[batch_dim_idx] = s_kv * h * d; @@ -267,8 +268,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; - case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD: - case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { strideA[batch_dim_idx] = h * d; @@ -425,28 +427,42 @@ __device__ void cu_seqlens_padded_to_offsets_impl( size_t tid = blockIdx.x * blockDim.x + threadIdx.x; auto cu_seqlens_id = min(tid, actual_b); if (tid <= max_b) { - offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; if (offsets_s != nullptr) { offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; } - switch (layout_group) { - case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; - offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; - break; - case NVTE_QKV_Layout_Group::NVTE_3HD: - case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - offsets_k[tid] = offsets_q[cu_seqlens_id]; - offsets_v[tid] = offsets_q[cu_seqlens_id]; - break; - case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - 
offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; - offsets_v[tid] = offsets_k[cu_seqlens_id]; - break; + if (offsets_q != nullptr && offsets_o != nullptr) { + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_3HD: + case NVTE_QKV_Layout_Group::NVTE_H3D: + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + break; + } + } + if (offsets_k != nullptr && offsets_v != nullptr) { + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_3HD: + case NVTE_QKV_Layout_Group::NVTE_H3D: + offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_v[tid] = offsets_v[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = offsets_k[cu_seqlens_id]; + break; + } } } } diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 09bd40d77a..9a6ab099ea 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -43,12 +43,14 @@ enum NVTE_QKV_Layout { NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ - NVTE_Paged_KV_BSHD_2BSHD = 15, /*!< Paged_KV_BSHD_2BSHD layout */ - 
NVTE_Paged_KV_BSHD_2SBHD = 16, /*!< Paged_KV_BSHD_2SBHD layout */ - NVTE_Paged_KV_SBHD_2BSHD = 17, /*!< Paged_KV_SBHD_2BSHD layout */ - NVTE_Paged_KV_SBHD_2SBHD = 18, /*!< Paged_KV_SBHD_2SBHD layout */ - NVTE_Paged_KV_THD_2BSHD = 19, /*!< Paged_KV_THD_2BSHD layout */ - NVTE_Paged_KV_THD_2SBHD = 20, /*!< Paged_KV_THD_2SBHD layout */ + NVTE_THD_BSHD_BSHD = 15, /*!< THD_BSHD_BSHD layout */ + NVTE_THD_SBHD_SBHD = 16, /*!< THD_SBHD_SBHD layout */ + NVTE_Paged_KV_BSHD_BSHD_BSHD = 17, /*!< Paged_KV_BSHD_BSHD_BSHD layout */ + NVTE_Paged_KV_BSHD_SBHD_SBHD = 18, /*!< Paged_KV_BSHD_SBHD_SBHD layout */ + NVTE_Paged_KV_SBHD_BSHD_BSHD = 19, /*!< Paged_KV_SBHD_BSHD_BSHD layout */ + NVTE_Paged_KV_SBHD_SBHD_SBHD = 20, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ + NVTE_Paged_KV_THD_BSHD_BSHD = 21, /*!< Paged_KV_THD_BSHD_BSHD layout */ + NVTE_Paged_KV_THD_SBHD_SBHD = 22, /*!< Paged_KV_THD_SBHD_SBHD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -65,22 +67,28 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_H2D = 3, /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */ NVTE_HD_HD_HD = 4, - /*! Paged_KV_2BSHD QKV layouts, e.g. Paged_KV_THD_2BSHD */ - NVTE_Paged_KV_2BSHD = 5, - /*! Paged_KV_2SBHD QKV layouts, e.g. Paged_KV_BSHD_2SBHD */ - NVTE_Paged_KV_2SBHD = 6, + /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ + NVTE_Paged_KV_HD_HD_HD = 5, }; /*! \enum NVTE_QKV_Format * \brief QKV formats */ enum NVTE_QKV_Format { - /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_2BSHD, Paged_KV_SBHD_2SBHD */ + /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_SBHD_SBHD */ NVTE_SBHD = 0, - /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_2BSHD, Paged_KV_BSHD_2SBHD */ + /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_BSHD_BSHD */ NVTE_BSHD = 1, - /*! THD QKV format, i.e. 
T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD, Paged_KV_THD_2BSHD, Paged_KV_THD_2SBHD */ + /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ NVTE_THD = 2, + /*! BSHD format for Q and SBHD format for KV, i.e. Paged_KV_BSHD_SBHD_SBHD */ + NVTE_BSHD_2SBHD = 3, + /*! SBHD format for Q and BSHD format for KV, i.e. Paged_KV_SBHD_BSHD_BSHD */ + NVTE_SBHD_2BSHD = 4, + /*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */ + NVTE_THD_2BSHD = 5, + /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ + NVTE_THD_2SBHD = 6, }; /*! \enum NVTE_Bias_Type @@ -145,6 +153,22 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Get Q format for a given QKV layout. + * + * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * + * \return q format, e.g. sbhd. + */ +NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); + +/*! \brief Get KV format for a given QKV layout. + * + * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * + * \return kv format, e.g. sbhd. + */ +NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] q_dtype The data type of Tensor Q. @@ -322,6 +346,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k]. + * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v]. * \param[in] rng_state Seed and offset of CUDA random number generator. 
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -343,7 +369,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 768fd2797a..31cb2007bb 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -39,7 +39,11 @@ pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \ + .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ + .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ + .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -56,12 +60,14 @@ .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ - .value("NVTE_Paged_KV_BSHD_2BSHD", 
NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD) \ - .value("NVTE_Paged_KV_BSHD_2SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD) \ - .value("NVTE_Paged_KV_SBHD_2BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD) \ - .value("NVTE_Paged_KV_SBHD_2SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD) \ - .value("NVTE_Paged_KV_THD_2BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD) \ - .value("NVTE_Paged_KV_THD_2SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD); \ + .value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \ + .value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 944930738e..70b3c8e013 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5431,10 +5431,22 @@ def get_qkv_layout( v: torch.Tensor Value tensor. It may be different from input `v` as we try to fit tensors to a supported layout. + q_format: str + Format of the query tensor, {`bshd`, `sbhd`, `thd`}. + kv_format: str + Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. 
""" check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" + if "_2" in qkv_format: + q_format, kv_format = qkv_format.split("_2") + is_same_q_kv_format = False + else: + q_format = qkv_format + kv_format = qkv_format + is_same_q_kv_format = True + print('qkv format', qkv_format, is_same_q_kv_format, q_format, kv_format) def run_iteratively(q, k, v): # check data pointers @@ -5516,12 +5528,26 @@ def run_iteratively(q, k, v): check_strides_kv and check_shapes_kv and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) + and is_same_q_kv_format ): # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd # three chunks of memory, q, k and v, which may be disjoint or consecutive, and # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = "_".join(list([qkv_format]) * 3) + print('xxxxx0') + elif ( + check_strides_kv + and check_shapes_kv + and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) + and not is_same_q_kv_format + ): + # sbhd_bshd_bshd, bshd_sbhd_sbhd, thd_bshd_bshd, thd_sbhd_sbhd + # three chunks of memory, q, k and v, which may be disjoint or consecutive, and + # when consecutive, they may have the same data pointer, i.e. 
check_ptrs_qkv=True or + # check_ptrs_qk=True or check_ptrs_kv=True + qkv_layout = q_format + "_" + kv_format + "_" + kv_format + print('xxxxx1') else: qkv_layout = "not_supported" @@ -5535,7 +5561,7 @@ def run_iteratively(q, k, v): if qkv_layout == "not_supported": raise RuntimeError("The provided qkv memory layout is not supported!") - return qkv_layout, q, k, v + return qkv_layout, q, k, v, q_format, kv_format def check_set_window_size( @@ -7439,6 +7465,24 @@ def forward( # max_seqlen_q = inference_params.max_seqlen_q # max_seqlen_kv = inference_params.max_seqlen_kv + if ( + isinstance(query_layer, Float8Tensor) + and isinstance(key_layer, Float8Tensor) + and isinstance(value_layer, Float8Tensor) + ): + qkv_layout, query_layer._data, key_layer._data, value_layer._data, q_format, kv_format = get_qkv_layout( + query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format + ) + else: + qkv_layout, query_layer, key_layer, value_layer, q_format, kv_format = get_qkv_layout( + query_layer, key_layer, value_layer, qkv_format=qkv_format + ) + # convert qkv layout to its corresponding paged attention layout + if inference_params is not None and inference_params.is_paged: + #qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format + #qkv_layout = "paged_kv_thd_2bshd"# + qkv_format + "_2" + qkv_format + qkv_layout = "paged_kv_" + qkv_layout + cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -7447,26 +7491,35 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - if qkv_format in ["sbhd", "bshd"]: + if q_format in ["sbhd", "bshd"]: max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size - if cu_seqlens_q is None or cu_seqlens_kv is None: + if cu_seqlens_q is None: if "padding" in attn_mask_type: assert ( attention_mask is not None ), "Please provide attention_mask for padding!" 
if self.attention_type == "self": cu_seqlens_q = get_cu_seqlens(attention_mask) - cu_seqlens_kv = cu_seqlens_q else: cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) else: cu_seqlens_q = _get_full_cu_seqlens( batch_size, max_seqlen_q, query_layer.device, ) + if kv_format in ["sbhd", "bshd"]: + max_seqlen_kv *= cp_size + if cu_seqlens_kv is None: + if "padding" in attn_mask_type: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + if self.attention_type == "self": + cu_seqlens_kv = get_cu_seqlens(attention_mask) + else: + cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + else: cu_seqlens_kv = _get_full_cu_seqlens( batch_size, max_seqlen_kv, @@ -7477,22 +7530,6 @@ def forward( #print('cu_seqlens_q ', cu_seqlens_q) #print('cu_seqlens_kv ', cu_seqlens_kv) - if ( - isinstance(query_layer, Float8Tensor) - and isinstance(key_layer, Float8Tensor) - and isinstance(value_layer, Float8Tensor) - ): - qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout( - query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format - ) - else: - qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout( - query_layer, key_layer, value_layer, qkv_format=qkv_format - ) - # convert qkv layout to its corresponding paged attention layout - if inference_params is not None and inference_params.is_paged: - qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format - global _alibi_cache if alibi_slopes is not None: assert ( diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 3b590def22..e4dd0772cc 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -54,12 +54,14 @@ "thd_t2hd", "thd_th2d", "thd_thd_thd", - "paged_kv_bshd_2bshd", - "paged_kv_bshd_2sbhd", - "paged_kv_sbhd_2bshd", - "paged_kv_sbhd_2sbhd", - "paged_kv_thd_2bshd", - "paged_kv_thd_2sbhd", + 
"thd_bshd_bshd", + "thd_sbhd_sbhd", + "paged_kv_bshd_bshd_bshd", + "paged_kv_bshd_sbhd_sbhd", + "paged_kv_sbhd_bshd_bshd", + "paged_kv_sbhd_sbhd_sbhd", + "paged_kv_thd_bshd_bshd", + "paged_kv_thd_sbhd_sbhd", ) LayerTypes = ("encoder", "decoder") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 6b194f963a..34fa598df8 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -36,6 +36,10 @@ "bshd": NVTE_QKV_Format.NVTE_BSHD, "sbhd": NVTE_QKV_Format.NVTE_SBHD, "thd": NVTE_QKV_Format.NVTE_THD, + "sbhd_2bshd": NVTE_QKV_Format.NVTE_SBHD_2BSHD, + "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, + "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, + "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, } QKVLayout = { @@ -54,12 +58,14 @@ "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD, "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D, "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD, - "paged_kv_bshd_2bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_2BSHD, - "paged_kv_bshd_2sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_2SBHD, - "paged_kv_sbhd_2bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_2BSHD, - "paged_kv_sbhd_2sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_2SBHD, - "paged_kv_thd_2bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_2BSHD, - "paged_kv_thd_2sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_2SBHD, + "thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD, + "thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD, + "paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD, + "paged_kv_bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_SBHD_SBHD, + "paged_kv_sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_BSHD_BSHD, + "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, + "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, + "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, } AttnBiasType = { @@ 
-268,6 +274,32 @@ def fused_attn_fwd( # execute kernel + print(max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + dropout, + fast_zero_fill, + QKVLayout[qkv_layout], + AttnBiasType[attn_bias_type], + AttnMaskType[attn_mask_type], + window_size, + cu_seqlens_q, + cu_seqlens_kv, + q.shape, + k.shape, + v.shape, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + attn_bias, + rng_gen, + rng_elts_per_thread, + ) output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index a1d5526784..25b13a9e81 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -116,6 +116,7 @@ def __init__( # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format self.sequences_prev = collections.OrderedDict() self.sequences = collections.OrderedDict() @@ -221,6 +222,7 @@ def pre_step( torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") ) seq_lens = list(self.sequences.values()) + #seq_lens = [self.max_seqlen_kv] * self.batch_size cu_seqlens_kv = [0] + [sum(seq_lens[:i]) for i in range(1, actual_batch_size + 1)] cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * (self.max_batch_size - actual_batch_size) self.cu_seqlens_kv.copy_( @@ -338,6 +340,7 @@ def step( The page table if is_paged = True; else `None` """ self.input_qkv_format = qkv_format + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format if qkv_format == "bshd": q_buffer = q.contiguous() @@ -346,27 +349,28 @@ def step( q_buffer = q.transpose(0, 1).contiguous() self.max_seqlen_q = q_buffer.shape[1] if qkv_format == "thd": - self.q_orig[layer_number] = q - self.max_seqlen_q = self.max_ctx_len - - q_buffer = 
self.q_buffer[layer_number] - step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - ctx_len = 1 - if qkv_format == "bshd": - ctx_len = q.shape[1] - if qkv_format == "sbhd": - ctx_len = q.shape[0] - tex.reshape_q( - q, q_buffer, step_lens, - QKVFormat[qkv_format], - self.num_heads_q, self.head_dim_q, - self.max_batch_size, ctx_len, self.max_ctx_len) + q_buffer = q + # self.q_orig[layer_number] = q + # self.max_seqlen_q = self.max_ctx_len + + # q_buffer = self.q_buffer[layer_number] + # step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + # ctx_len = 1 + # if qkv_format == "bshd": + # ctx_len = q.shape[1] + # if qkv_format == "sbhd": + # ctx_len = q.shape[0] + # tex.reshape_q( + # q, q_buffer, step_lens, + # QKVFormat[qkv_format], + # self.num_heads_q, self.head_dim_q, + # self.max_batch_size, ctx_len, self.max_ctx_len) k_cache, v_cache, page_table = self.cache_manager.step( layer_number, k, v, self.cu_seqlens_q, self.cu_seqlens_kv, qkv_format, ) - return q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, self.max_seqlen_q, self.max_seqlen_kv, self.cache_qkv_format + return q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, self.max_seqlen_q, self.max_seqlen_kv, self.output_qkv_format def post_step( self, @@ -381,12 +385,13 @@ def post_step( if self.input_qkv_format == "sbhd": output = output[:self.batch_size, :self.max_seqlen_q].transpose(0, 1).contiguous() if self.input_qkv_format == "thd": - #print('oooo ', output[:2, :, :4]) - output_buffer = self.q_orig[layer_number] - step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - tex.reshape_o(output, output_buffer, step_lens, - self.num_heads_q, self.head_dim_q, self.batch_size, self.max_ctx_len, self.is_output_right_aligned) - output = output_buffer.view(output_buffer.shape[0], -1) + print('oooo ', output.shape) + print(output[:2, :4]) + #output_buffer = self.q_orig[layer_number] + #step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] 
+ #tex.reshape_o(output, output_buffer, step_lens, + # self.num_heads_q, self.head_dim_q, self.batch_size, self.max_ctx_len, self.is_output_right_aligned) + #output = output_buffer.view(output_buffer.shape[0], -1) return output From 7331a4c7e4a59f421a4a9dc378ae44cf1990f77b Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 21 Feb 2025 16:02:01 -0800 Subject: [PATCH 102/239] WIP: fix last commit Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 6 ++-- .../fused_attn_f16_arbitrary_seqlen.cu | 32 ++++++++++--------- transformer_engine/pytorch/inference.py | 2 +- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 60ee5e87ca..752d3f70a9 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -525,9 +525,9 @@ def gen_data(): rtol=tols[dtype], ) if qkv_format == "thd": - print('i ', i, seq, cu_seqlens_q) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(line_output[cu_seqlens_q[i + 1] - 1, :4]) + #print('i ', i, seq, cu_seqlens_q) + #print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + #print(line_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], #line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 313ca416da..b5710eab86 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -428,22 +428,23 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_ragged_q || is_ragged_kv) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; + void 
*devOffsets = + static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; void *devOffsetsQ = nullptr; - void *devOffsetsK = nullptr; - void *devOffsetsV = nullptr; void *devOffsetsO = nullptr; if (is_ragged_q) { - devOffsetsQ = - static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + devOffsetsQ = devOffsets; + devOffsetsO = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; } + void *devOffsetsK = nullptr; + void *devOffsetsV = nullptr; if (is_ragged_kv) { - devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + devOffsetsK = static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { - devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + devOffsetsS = static_cast(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( @@ -655,7 +656,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( o->set_ragged_offset(offset_o); dO->set_ragged_offset(offset_o); } - if (is_ragged_q) { + if (is_ragged_kv) { offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_k") .set_dim({b + 1, 1, 1, 1}) @@ -886,22 +887,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged_q || is_ragged_kv) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; + void *devOffsets = + static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; void *devOffsetsQ = nullptr; - void *devOffsetsK = nullptr; - void *devOffsetsV = nullptr; void *devOffsetsO = nullptr; if (is_ragged_q) { - devOffsetsQ = - static_cast(workspace) + plan_workspace_size 
+ actual_seqlen_workspace_size; - devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + devOffsetsQ = devOffsets; + devOffsetsO = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; } + void *devOffsetsK = nullptr; + void *devOffsetsV = nullptr; if (is_ragged_kv) { - devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + devOffsetsK = static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { - devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + devOffsetsS = static_cast(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 25b13a9e81..0d8b660369 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -386,7 +386,7 @@ def post_step( output = output[:self.batch_size, :self.max_seqlen_q].transpose(0, 1).contiguous() if self.input_qkv_format == "thd": print('oooo ', output.shape) - print(output[:2, :4]) + #print(output[:2, :4]) #output_buffer = self.q_orig[layer_number] #step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] #tex.reshape_o(output, output_buffer, step_lens, From 0341de79ddbf8f21fe061e34e107e7a3fdfd717f Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 21 Feb 2025 17:06:52 -0800 Subject: [PATCH 103/239] WIP: fix 1c31b68d Signed-off-by: Charlene Yang --- transformer_engine/common/fused_attn/utils.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index daf4ce71c1..1c6072da07 100644 --- 
a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -434,6 +434,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl( offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: @@ -449,6 +450,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl( if (offsets_k != nullptr && offsets_v != nullptr) { switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; break; @@ -495,6 +497,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at std::array offsets_qkvo{}; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q; offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv; offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv; From 6bd61a7b72c5f48ac746b94b506d94f5ff290e25 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 21 Feb 2025 17:38:51 -0800 Subject: [PATCH 104/239] WIP: add bshd_2sbhd, sbhd_2bshd Signed-off-by: Charlene Yang --- .../common/fused_attn/fused_attn.cpp | 4 ++ transformer_engine/common/fused_attn/utils.cu | 2 + .../include/transformer_engine/fused_attn.h | 22 ++++---- .../common/util/pybind_helper.h | 2 + transformer_engine/pytorch/constants.py | 2 + .../pytorch/cpp_extensions/fused_attn.py | 54 ++++++++++--------- 6 files changed, 50 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 
f7fc9149bd..0f31f60e41 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -37,6 +37,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; @@ -75,8 +77,10 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_TH2D: case NVTE_QKV_Layout::NVTE_THD_THD_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: return NVTE_QKV_Format::NVTE_SBHD_2BSHD; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_BSHD_2SBHD; case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 1c6072da07..a079f5a8fc 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -247,6 +247,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_transpose_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { @@ -268,6 +269,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: case 
NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 9a6ab099ea..070e85789d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -43,14 +43,16 @@ enum NVTE_QKV_Layout { NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ - NVTE_THD_BSHD_BSHD = 15, /*!< THD_BSHD_BSHD layout */ - NVTE_THD_SBHD_SBHD = 16, /*!< THD_SBHD_SBHD layout */ - NVTE_Paged_KV_BSHD_BSHD_BSHD = 17, /*!< Paged_KV_BSHD_BSHD_BSHD layout */ - NVTE_Paged_KV_BSHD_SBHD_SBHD = 18, /*!< Paged_KV_BSHD_SBHD_SBHD layout */ - NVTE_Paged_KV_SBHD_BSHD_BSHD = 19, /*!< Paged_KV_SBHD_BSHD_BSHD layout */ - NVTE_Paged_KV_SBHD_SBHD_SBHD = 20, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ - NVTE_Paged_KV_THD_BSHD_BSHD = 21, /*!< Paged_KV_THD_BSHD_BSHD layout */ - NVTE_Paged_KV_THD_SBHD_SBHD = 22, /*!< Paged_KV_THD_SBHD_SBHD layout */ + NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */ + NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */ + NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */ + NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */ + NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */ + NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */ + NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */ + NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ + NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ + NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -81,9 +83,9 @@ enum NVTE_QKV_Format { NVTE_BSHD = 1, /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ NVTE_THD = 2, - /*! 
BSHD format for Q and SBHD format for KV, i.e. Paged_KV_BSHD_SBHD_SBHD */ + /*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */ NVTE_BSHD_2SBHD = 3, - /*! SBHD format for Q and BSHD format for KV, i.e. Paged_KV_SBHD_BSHD_BSHD */ + /*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */ NVTE_SBHD_2BSHD = 4, /*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */ NVTE_THD_2BSHD = 5, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 31cb2007bb..b8c8df37ee 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -60,6 +60,8 @@ .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ + .value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \ + .value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \ .value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \ .value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \ .value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \ diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index e4dd0772cc..8cbf5aea98 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -54,6 +54,8 @@ "thd_t2hd", "thd_th2d", "thd_thd_thd", + "sbhd_bshd_bshd", + "bshd_sbhd_sbhd", "thd_bshd_bshd", "thd_sbhd_sbhd", "paged_kv_bshd_bshd_bshd", diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 34fa598df8..a58cc4d0e9 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -58,6 +58,8 
@@ "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD, "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D, "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD, + "sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD, + "bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD, "thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD, "thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD, "paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD, @@ -274,32 +276,32 @@ def fused_attn_fwd( # execute kernel - print(max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens_q, - cu_seqlens_kv, - q.shape, - k.shape, - v.shape, - fake_dtype, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - page_table_k, - page_table_v, - s_quantizer, - o_quantizer, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) + #print(max_seqlen_q, + # max_seqlen_kv, + # is_training, + # attn_scale, + # dropout, + # fast_zero_fill, + # QKVLayout[qkv_layout], + # AttnBiasType[attn_bias_type], + # AttnMaskType[attn_mask_type], + # window_size, + # cu_seqlens_q, + # cu_seqlens_kv, + # q.shape, + # k.shape, + # v.shape, + # fake_dtype, + # cu_seqlens_q_padded, + # cu_seqlens_kv_padded, + # page_table_k, + # page_table_v, + # s_quantizer, + # o_quantizer, + # attn_bias, + # rng_gen, + # rng_elts_per_thread, + #) output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, From 2d30bb1271412066170ed15b79fd418e5993ae39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Feb 2025 01:41:00 +0000 Subject: [PATCH 105/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_paged_attn.py | 152 +++++--- .../common/fused_attn/fused_attn.cpp | 33 +- .../fused_attn_f16_arbitrary_seqlen.cu | 63 ++-- .../fused_attn_f16_arbitrary_seqlen.h 
| 19 +- .../include/transformer_engine/fused_attn.h | 56 ++- transformer_engine/pytorch/attention.py | 139 ++++--- .../pytorch/cpp_extensions/fused_attn.py | 4 +- transformer_engine/pytorch/csrc/extensions.h | 33 +- .../pytorch/csrc/extensions/attention.cu | 340 ++++++++---------- transformer_engine/pytorch/graph.py | 17 +- transformer_engine/pytorch/inference.py | 60 ++-- .../pytorch/kv_cache_manager.py | 2 + .../pytorch/kv_cache_manager_non_paged.py | 43 ++- .../pytorch/kv_cache_manager_paged.py | 34 +- 14 files changed, 542 insertions(+), 453 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 752d3f70a9..b347340d3b 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -39,8 +39,10 @@ model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "infer_0": ModelConfig(4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16), - #"infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), + "infer_0": ModelConfig( + 4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 + ), + # "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -49,9 +51,11 @@ def to_pretty_string(x: torch.Tensor): return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]" + def round_up(a: int, b: int): return b * math.ceil(a / b) + class Simulation: def __init__( self, @@ -71,13 +75,13 @@ def __init__( self.max_gen_len = max_seq_len - self.max_ctx_len # simulate sequence ids in monotonically increasing fashion - self.seq_ids = torch.range(0, total_requests-1, dtype=torch.int32, device="cpu") + self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu") # simulate context lengths in Uniform distribution self.context_lens = torch.randint( 1, self.max_ctx_len, 
[total_requests], dtype=torch.int32, device="cpu" ) - #self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -85,18 +89,18 @@ def __init__( gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to( dtype=torch.int32, device="cpu" ) - self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to( - dtype=torch.int32, device="cpu" - ) - #self.gen_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu") + # self.gen_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate arrival times in Poisson distribution if poisson_rate is None: self.poisson_rate = torch.randint(1, max_batch_size, [1]).item() interval_dist = Exponential(self.poisson_rate) arrival_intervals = interval_dist.sample((total_requests,)) - self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu") - #self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") + self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to( + dtype=torch.int32, device="cpu" + ) + # self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() # initialize tensors @@ -144,10 +148,10 @@ def print_step(self, logger): def print_summary(self, logger): logger.info("Summary:") logger.info(" {:<18s}: {}".format("total steps taken", self.t)) - logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times))) - logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times))) - logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens))) - logger.info(" {:<18s}: 
{}".format("complete_times", to_pretty_string(self.complete_times))) + logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times))) + logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times))) + logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens))) + logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times))) def add_new_seqs(self, new_seq_ids): # get ctx_lens for new seqs @@ -194,11 +198,11 @@ def step(self, dynamic_fill: bool = True): self.t_total_lens = self.t_ctx_lens + self.t_gen_lens -@pytest.mark.parametrize("dtype", [torch.float16])#param_types) +@pytest.mark.parametrize("dtype", [torch.float16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats) +@pytest.mark.parametrize("qkv_format", ["thd"]) # qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() @@ -253,7 +257,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): # generate data for all requests assert ( config.max_seqlen_q == config.max_seqlen_kv - ), "This test only simulates max_seqlen_q = max_seqlen_kv." + ), "This test only simulates max_seqlen_q = max_seqlen_kv." 
q = 0.1 * torch.randn( (config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk), dtype=dtype, @@ -297,7 +301,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): max_ctx_len=config.max_ctx_len, max_batch_size=max_batch_size, poisson_rate=2, - ) + ) sim.print_setup(logger) # initialize inference_params @@ -322,41 +326,45 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): if is_cuda_graph: t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu") - step_dict = OrderedDict( - zip(t_seq_ids.tolist(), step_lens.tolist()) - ) + step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist())) inference_params.pre_step(step_dict) if qkv_format == "bshd": - shape = [ config.batch_size, config.max_ctx_len] + shape = [config.batch_size, config.max_ctx_len] if qkv_format == "sbhd": - shape = [ config.max_ctx_len, config.batch_size] + shape = [config.max_ctx_len, config.batch_size] if qkv_format == "thd": - shape = [ config.batch_size * config.max_ctx_len] + shape = [config.batch_size * config.max_ctx_len] + def gen_data(): - return [torch.ones( - *shape, - config.num_heads, - config.head_dim_qk, - device="cuda", - dtype=dtype, - ) for _ in range(3)] + return [ + torch.ones( + *shape, + config.num_heads, + config.head_dim_qk, + device="cuda", + dtype=dtype, + ) + for _ in range(3) + ] sample_kwargs = {} - sample_kwargs["cu_seqlens_q"] = torch.linspace( 0, + sample_kwargs["cu_seqlens_q"] = torch.linspace( + 0, config.batch_size * config.max_ctx_len, - steps=config.batch_size+1, + steps=config.batch_size + 1, device="cuda", dtype=torch.int32, ) - sample_kwargs["cu_seqlens_kv"] = torch.linspace( 0, + sample_kwargs["cu_seqlens_kv"] = torch.linspace( + 0, config.batch_size * config.max_ctx_len, - steps=config.batch_size+1, + steps=config.batch_size + 1, device="cuda", 
dtype=torch.int32, ) sample_kwargs["inference_params"] = inference_params - sample_kwargs["attn_mask_type"] = "padding" #_causal" + sample_kwargs["attn_mask_type"] = "padding" # _causal" sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv sample_kwargs["qkv_format"] = qkv_format @@ -386,7 +394,7 @@ def gen_data(): max_tokens = config.batch_size * config.max_ctx_len while True: # prepare batch for the current step - dynamic_fill = True #inference_params.is_paged + dynamic_fill = True # inference_params.is_paged sim.step(dynamic_fill=dynamic_fill) sim.print_step(logger) @@ -427,9 +435,47 @@ def gen_data(): dim=0, ) if is_cuda_graph: - incremental_q = torch.cat([incremental_q, torch.zeros([max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], dtype=dtype, device=incremental_q.device)], dim=0) - incremental_k = torch.cat([incremental_k, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_k.device)], dim=0) - incremental_v = torch.cat([incremental_v, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_v.device)], dim=0) + incremental_q = torch.cat( + [ + incremental_q, + torch.zeros( + [max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], + dtype=dtype, + device=incremental_q.device, + ), + ], + dim=0, + ) + incremental_k = torch.cat( + [ + incremental_k, + torch.zeros( + [ + max_tokens - sum(sim.step_lens), + config.num_gqa_groups, + config.head_dim_v, + ], + dtype=dtype, + device=incremental_k.device, + ), + ], + dim=0, + ) + incremental_v = torch.cat( + [ + incremental_v, + torch.zeros( + [ + max_tokens - sum(sim.step_lens), + config.num_gqa_groups, + config.head_dim_v, + ], + dtype=dtype, + device=incremental_v.device, + ), + ], + dim=0, + ) else: incremental_q = torch.zeros( batch_size, @@ -472,9 +518,7 @@ def gen_data(): cu_seqlens_q[1 : 
sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) cu_seqlens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) - step_dict = OrderedDict( - zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()) - ) + step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist())) inference_params.pre_step(step_dict) if inference_params.is_paged: inference_params.cache_manager.print_cache() @@ -485,7 +529,7 @@ def gen_data(): cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, - attn_mask_type="padding", #_causal", + attn_mask_type="padding", # _causal", max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, qkv_format=qkv_format, @@ -508,8 +552,8 @@ def gen_data(): token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 if qkv_format == "bshd": torch.testing.assert_close( - #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - #line_output[:sim.step_lens[i] - 1, i, :], + # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + # line_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], line_output[i, token_index, :], atol=tols[dtype], @@ -517,20 +561,20 @@ def gen_data(): ) if qkv_format == "sbhd": torch.testing.assert_close( - #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - #line_output[:sim.step_lens[i] - 1, i, :], + # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + # line_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], line_output[token_index, i, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "thd": - #print('i ', i, seq, cu_seqlens_q) - #print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - #print(line_output[cu_seqlens_q[i + 1] - 1, :4]) + # print('i ', i, seq, 
cu_seqlens_q) + # print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + # print(line_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( - #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - #line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], + # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], + # line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], full_output[seq, sim.t_total_lens[i] - 1, :], line_output[cu_seqlens_q[i + 1] - 1, :], atol=tols[dtype], diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0f31f60e41..b117daeffd 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -290,8 +290,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window // pre-9.2: full attn, causal @@ -542,16 +544,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const 
NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); @@ -641,10 +642,10 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, - input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, 
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, - handle); + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index b5710eab86..c3a650f251 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -342,10 +342,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) : std::make_tuple(nullptr, nullptr); - auto offset_qo_tuple = is_ragged_q ? std::make_tuple(offset_q, offset_o) - : std::make_tuple(nullptr, nullptr); - auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) - : std::make_tuple(nullptr, nullptr); + auto offset_qo_tuple = + is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = + is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) ? 
std::make_tuple(offset_stats) : std::make_tuple(nullptr); @@ -358,9 +358,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, page_table_tuple, - offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, + page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; @@ -439,12 +439,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsK = nullptr; void *devOffsetsV = nullptr; if (is_ragged_kv) { - devOffsetsK = static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; + devOffsetsK = + static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { - devOffsetsS = static_cast(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; + devOffsetsS = static_cast(devOffsets) + + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( @@ -793,10 +795,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_qo_tuple = is_ragged_q ? std::make_tuple(offset_q, offset_o) - : std::make_tuple(nullptr, nullptr); - auto offset_kv_tuple = is_ragged_kv ? 
std::make_tuple(offset_k, offset_v) - : std::make_tuple(nullptr, nullptr); + auto offset_qo_tuple = + is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = + is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); @@ -898,12 +900,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsK = nullptr; void *devOffsetsV = nullptr; if (is_ragged_kv) { - devOffsetsK = static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; + devOffsetsK = + static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { - devOffsetsS = static_cast(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; + devOffsetsS = static_cast(devOffsets) + + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( @@ -1148,16 +1152,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, - size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack 
*Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1266,13 +1269,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, - max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), 
workspace->data.dptr, &workspace_size, - stream, handle); + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index c6cc212211..e1a20274f4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -38,16 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, - size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t 
window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 070e85789d..2e6be8d178 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -28,25 +28,25 @@ extern "C" { * different lengths in a batch. `Paged_KV`-based layouts are used for paged attention. 
*/ enum NVTE_QKV_Layout { - NVTE_SB3HD = 0, /*!< SB3HD layout */ - NVTE_SBH3D = 1, /*!< SBH3D layout */ - NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ - NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ - NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ - NVTE_BS3HD = 5, /*!< BS3HD layout */ - NVTE_BSH3D = 6, /*!< BSH3D layout */ - NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ - NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ - NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ - NVTE_T3HD = 10, /*!< T3HD layout */ - NVTE_TH3D = 11, /*!< TH3D layout */ - NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ - NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ - NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ - NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */ - NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */ - NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */ - NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */ + NVTE_SB3HD = 0, /*!< SB3HD layout */ + NVTE_SBH3D = 1, /*!< SBH3D layout */ + NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ + NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ + NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ + NVTE_BS3HD = 5, /*!< BS3HD layout */ + NVTE_BSH3D = 6, /*!< BSH3D layout */ + NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ + NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ + NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ + NVTE_T3HD = 10, /*!< T3HD layout */ + NVTE_TH3D = 11, /*!< TH3D layout */ + NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ + NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ + NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */ + NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */ + NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */ + NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */ NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */ NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout 
*/ NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */ @@ -367,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. 
* diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 70b3c8e013..3a3dcc90b8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -530,21 +530,27 @@ def get_attention_backend( logger.debug("Disabling FlashAttention for FP8 KV caching") if use_fused_attention and inference_params.is_paged: use_fused_attention = False - logger.debug("Disabling FusedAttention as it does not support paged attention in FP8") + logger.debug( + "Disabling FusedAttention as it does not support paged attention in FP8" + ) if use_unfused_attention: use_unfused_attention = False logger.debug("Disabling UnfusedAttention as it does not support FP8 attention") else: if use_flash_attention and not _flash_attn_2_2_plus and not _use_flash_attn_3: use_flash_attention = False - logger.debug("Disabling FlashAttention as KV caching requires flash-attn 2.2+, or 3.0 (Hopper only)") + logger.debug( + "Disabling FlashAttention as KV caching requires flash-attn 2.2+, or 3.0" + " (Hopper only)" + ) if inference_params.is_paged: if use_fused_attention and cudnn_version < (9, 5, 0): logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+") use_fused_attention = False if use_flash_attention and not _use_flash_attn_3 and not _flash_attn_2_5_plus: logger.debug( - "Disabling FlashAttention as paged attention requires flash-attn 2.5+, or 3.0 (Hopper only)" + "Disabling FlashAttention as paged attention requires flash-attn 2.5+, or 3.0" + " (Hopper only)" ) use_flash_attention = False @@ -1042,6 +1048,7 @@ def get_attention_backend( available_backends, ) + @torch.no_grad() def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] @@ -1052,9 +1059,7 @@ def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq attention_mask_q = torch.cat( [ attention_mask_q, - torch.Tensor( - [False] * seqlens_q[i] 
+ [True] * (max_seqlen_q - seqlens_q[i]) - ) + torch.Tensor([False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i])) .to(dtype=torch.bool) .unsqueeze(0) .unsqueeze(0) @@ -1065,10 +1070,7 @@ def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq attention_mask_kv = torch.cat( [ attention_mask_kv, - torch.Tensor( - [False] * seqlens_kv[i] - + [True] * (max_seqlen_kv - seqlens_kv[i]) - ) + torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i])) .to(dtype=torch.bool) .unsqueeze(0) .unsqueeze(0) @@ -1082,6 +1084,7 @@ def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq ) return attention_mask + @torch.no_grad() def get_full_mask( max_seqlen_q: int, @@ -3101,9 +3104,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_plus - ): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_backward_kwargs["window_size"] = (-1, 0) elif _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -3216,9 +3217,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_plus - ): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_backward_kwargs["window_size"] = (-1, -1) if _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -3334,9 +3333,7 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_plus - ): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -3883,9 +3880,7 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_plus - ): + 
if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): fa_forward_kwargs["window_size"] = window_size_per_step[i] elif _flash_attn_2_7_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -5205,7 +5200,9 @@ def forward( ) if "padding" in attn_mask_type and qkv_format in ["bshd", "sbhd"]: - attention_mask = get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + attention_mask = get_attn_mask( + batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv + ) attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( max_seqlen_q, max_seqlen_kv, @@ -5446,7 +5443,7 @@ def get_qkv_layout( q_format = qkv_format kv_format = qkv_format is_same_q_kv_format = True - print('qkv format', qkv_format, is_same_q_kv_format, q_format, kv_format) + print("qkv format", qkv_format, is_same_q_kv_format, q_format, kv_format) def run_iteratively(q, k, v): # check data pointers @@ -5535,7 +5532,7 @@ def run_iteratively(q, k, v): # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = "_".join(list([qkv_format]) * 3) - print('xxxxx0') + print("xxxxx0") elif ( check_strides_kv and check_shapes_kv @@ -5547,7 +5544,7 @@ def run_iteratively(q, k, v): # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = q_format + "_" + kv_format + "_" + kv_format - print('xxxxx1') + print("xxxxx1") else: qkv_layout = "not_supported" @@ -5733,8 +5730,8 @@ def forward( if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" 
- cu_seqlens_q = cu_seqlens_q[:batch_size+1] - cu_seqlens_kv = cu_seqlens_kv[:batch_size+1] + cu_seqlens_q = cu_seqlens_q[: batch_size + 1] + cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] if inference_params is None: # [b * s, h, d] @@ -5851,11 +5848,15 @@ def forward( if inference_params is not None: func = flash_attn_with_kvcache fa_optional_forward_kwargs_kvcache = {} - fa_optional_forward_kwargs_kvcache["cache_seqlens"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_optional_forward_kwargs_kvcache["cache_seqlens"] = ( + cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + ) fa_optional_forward_kwargs_kvcache["softmax_scale"] = self.softmax_scale fa_optional_forward_kwargs_kvcache["causal"] = "causal" in attn_mask_type if inference_params.is_paged: - fa_optional_forward_kwargs_kvcache["block_table"] = inference_params.cache_manager.page_table[:batch_size] + fa_optional_forward_kwargs_kvcache["block_table"] = ( + inference_params.cache_manager.page_table[:batch_size] + ) output = func( query_layer, key_layer, @@ -5906,8 +5907,8 @@ def convert_to_torch_float8(tensor, dtype): fa_3_optional_forward_kwargs["descale_q"] = ( query_layer._scale_inv.unsqueeze(0) ) - fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( - 0 + fa_3_optional_forward_kwargs["descale_k"] = ( + key_layer._scale_inv.unsqueeze(0) ) fa_3_optional_forward_kwargs["descale_v"] = ( value_layer._scale_inv.unsqueeze(0) @@ -5930,7 +5931,8 @@ def convert_to_torch_float8(tensor, dtype): if _flash_attn_3_0_0_beta: e.args = ( e.args[0] - + ". Please update your flash-attn v3 (beta) installation as it " + + ". Please update your flash-attn v3 (beta) installation" + " as it " + "may have added more supported arguments to its API. 
\n" + _flash_attn_3_installation_steps, ) + e.args[1:] @@ -5953,7 +5955,11 @@ def convert_to_torch_float8(tensor, dtype): **fa_optional_forward_kwargs, ) - if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type and inference_params is None: + if ( + qkv_format in ["sbhd", "bshd"] + and "padding" in attn_mask_type + and inference_params is None + ): output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format == "sbhd": @@ -7416,7 +7422,7 @@ def forward( # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference - #if "padding" not in attn_mask_type: + # if "padding" not in attn_mask_type: # attn_mask_type = "padding_" + attn_mask_type if attn_mask_type in ["causal", "padding_causal"]: attn_mask_type = attn_mask_type + "_bottom_right" @@ -7437,29 +7443,39 @@ def forward( # cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but # flash-attention and unfused attention will need q/k/v in the # same qkv_format - #target_qkv_format = inference_params.qkv_format - #query_layer = inference_params.reshape_and_copy_q( + # target_qkv_format = inference_params.qkv_format + # query_layer = inference_params.reshape_and_copy_q( # query_layer, qkv_format, target_qkv_format, self.layer_number - #) + # ) # update KV cache and return the full key/value tensors # full key/value tensors are in inference_params.qkv_format format - #print('query_layer',query_layer.shape, query_layer.dtype) - #print('query_layer', query_layer[8,0,:4]) - query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, qkv_format = inference_params.step( + # print('query_layer',query_layer.shape, query_layer.dtype) + # print('query_layer', query_layer[8,0,:4]) + ( + query_layer, + key_layer, + value_layer, + page_table, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + qkv_format, + ) = inference_params.step( 
self.layer_number, query_layer, key_layer, value_layer, qkv_format, ) - #print('ssss0 ',query_layer.shape, key_layer.shape, value_layer.shape) - #print('cu_seqlens_q',cu_seqlens_q) - #print('cu_seqlens_kv',cu_seqlens_kv) - #print('maxxxxx ',max_seqlen_q, max_seqlen_kv) + # print('ssss0 ',query_layer.shape, key_layer.shape, value_layer.shape) + # print('cu_seqlens_q',cu_seqlens_q) + # print('cu_seqlens_kv',cu_seqlens_kv) + # print('maxxxxx ',max_seqlen_q, max_seqlen_kv) # update cu_seqlens tensors - #if inference_params.is_cuda_graph: + # if inference_params.is_cuda_graph: # cu_seqlens_q = inference_params.cu_seqlens_q_buffer # cu_seqlens_kv = inference_params.cu_seqlens_kv_buffer # max_seqlen_q = inference_params.max_seqlen_q @@ -7470,17 +7486,24 @@ def forward( and isinstance(key_layer, Float8Tensor) and isinstance(value_layer, Float8Tensor) ): - qkv_layout, query_layer._data, key_layer._data, value_layer._data, q_format, kv_format = get_qkv_layout( + ( + qkv_layout, + query_layer._data, + key_layer._data, + value_layer._data, + q_format, + kv_format, + ) = get_qkv_layout( query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format ) else: - qkv_layout, query_layer, key_layer, value_layer, q_format, kv_format = get_qkv_layout( - query_layer, key_layer, value_layer, qkv_format=qkv_format + qkv_layout, query_layer, key_layer, value_layer, q_format, kv_format = ( + get_qkv_layout(query_layer, key_layer, value_layer, qkv_format=qkv_format) ) # convert qkv layout to its corresponding paged attention layout if inference_params is not None and inference_params.is_paged: - #qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format - #qkv_layout = "paged_kv_thd_2bshd"# + qkv_format + "_2" + qkv_format + # qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format + # qkv_layout = "paged_kv_thd_2bshd"# + qkv_format + "_2" + qkv_format qkv_layout = "paged_kv_" + qkv_layout cp_size = 1 @@ -7525,10 +7548,10 @@ def forward( max_seqlen_kv, key_layer.device, ) - 
#print('max_seqlen_q ', max_seqlen_q) - #print('max_seqlen_kv ', max_seqlen_kv) - #print('cu_seqlens_q ', cu_seqlens_q) - #print('cu_seqlens_kv ', cu_seqlens_kv) + # print('max_seqlen_q ', max_seqlen_q) + # print('max_seqlen_kv ', max_seqlen_kv) + # print('cu_seqlens_q ', cu_seqlens_q) + # print('cu_seqlens_kv ', cu_seqlens_kv) global _alibi_cache if alibi_slopes is not None: @@ -7753,9 +7776,9 @@ def forward( fp8_meta=self.fp8_meta, quantizers=self.quantizers, ) - #print('ooooooooooo ',output.shape) - #print(output[1,9,:4]) - #print(output[1,10,:4]) + # print('ooooooooooo ',output.shape) + # print(output[1,9,:4]) + # print(output[1,10,:4]) from .cpu_offload import CPUOffloadEnabled diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index a58cc4d0e9..8bf9160480 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -276,7 +276,7 @@ def fused_attn_fwd( # execute kernel - #print(max_seqlen_q, + # print(max_seqlen_q, # max_seqlen_kv, # is_training, # attn_scale, @@ -301,7 +301,7 @@ def fused_attn_fwd( # attn_bias, # rng_gen, # rng_elts_per_thread, - #) + # ) output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 25e11070fc..7a4340bd39 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -34,27 +34,18 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T /*************************************************************************************************** * Attention **************************************************************************************************/ -void reshape_q( - torch::Tensor new_q, torch::Tensor q_buffer, - torch::Tensor step_lens, - NVTE_QKV_Format qkv_format, - int h_q, int d_q, int b, int max_ctx_len, int 
max_seq_len); - -void reshape_o( - torch::Tensor output, torch::Tensor output_buffer, - torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned); - -void copy_to_kv_cache( - torch::Tensor new_k, torch::Tensor new_v, - torch::Tensor k_cache, torch::Tensor v_cache, - torch::Tensor page_table, - torch::Tensor step_lens, - torch::Tensor seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged); +void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens, + NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, + int max_seq_len); + +void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, int h_o, + int d_o, int b, int max_seq_len, bool is_output_right_aligned); + +void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor step_lens, + torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, + int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged); NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 5ae2f19c5c..6d7b741aef 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -14,12 +14,9 @@ constexpr int block_size = 512; constexpr int ctas_per_sm = 4; template -__global__ void reshape_q_kernel( - scalar_t* new_q, - scalar_t* q_buffer, - int* step_lens, - NVTE_QKV_Format qkv_format, - int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { +__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *step_lens, 
+ NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, + int max_ctx_len, int max_seq_len) { // new_q: qkv_format; q_buffer: bshd // step_lens: [b + 1] if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { @@ -27,8 +24,8 @@ __global__ void reshape_q_kernel( int num_elts = step_lens[batch_idx] * h_q * d_q; int new_token_offset = batch_idx * max_ctx_len * h_q * d_q; int cache_offset = batch_idx * max_seq_len * h_q * d_q; - scalar_t* new_q_token = new_q + new_token_offset; - scalar_t* q_buffer_token = q_buffer + cache_offset; + scalar_t *new_q_token = new_q + new_token_offset; + scalar_t *q_buffer_token = q_buffer + cache_offset; for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { *(q_buffer_token + i) = *(new_q_token + i); } @@ -37,22 +34,23 @@ __global__ void reshape_q_kernel( for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int cache_offset = batch_idx * max_seq_len; for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h_q * d_q; j ++) { - *(q_buffer + (cache_offset + i) * h_q * d_q + j) = *(new_q + (i * b + batch_idx) * h_q * d_q +j); - } + for (int j = 0; j < h_q * d_q; j++) { + *(q_buffer + (cache_offset + i) * h_q * d_q + j) = + *(new_q + (i * b + batch_idx) * h_q * d_q + j); + } } } } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int num_elts = step_lens[batch_idx] * h_q * d_q; int new_token_offset = 0; - for (int t = 0; t < batch_idx; t ++) { - new_token_offset += step_lens[t]; + for (int t = 0; t < batch_idx; t++) { + new_token_offset += step_lens[t]; } new_token_offset = new_token_offset * h_q * d_q; int cache_offset = batch_idx * max_seq_len * h_q * d_q; - scalar_t* new_q_token = new_q + new_token_offset; - scalar_t* q_buffer_token = q_buffer + cache_offset; + scalar_t *new_q_token = new_q + new_token_offset; + scalar_t *q_buffer_token = q_buffer + cache_offset; for (int i = threadIdx.x; i < num_elts; i += 
blockDim.x) { *(q_buffer_token + i) = *(new_q_token + i); } @@ -61,58 +59,50 @@ __global__ void reshape_q_kernel( } template -void reshape_q_launcher( - torch::Tensor new_q, torch::Tensor q_buffer, - torch::Tensor step_lens, - NVTE_QKV_Format qkv_format, - int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { +void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens, + NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, + int max_seq_len) { reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_q.data_ptr()), - reinterpret_cast(q_buffer.data_ptr()), - step_lens.data_ptr(), + reinterpret_cast(new_q.data_ptr()), + reinterpret_cast(q_buffer.data_ptr()), step_lens.data_ptr(), qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); } -void reshape_q( - torch::Tensor new_q, torch::Tensor q_buffer, - torch::Tensor step_lens, - NVTE_QKV_Format qkv_format, - int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { +void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens, + NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, + int max_seq_len) { NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), - "new_q and q_buffer must be of the same data type."); - NVTE_CHECK( - qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_THD, - "qkv_format must be {BSHD, SBHD, THD}."); + "new_q and q_buffer must be of the same data type."); + NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_THD, + "qkv_format must be {BSHD, SBHD, THD}."); if (q_buffer.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, 
max_ctx_len, + max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, + max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); -// } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { -// using dtype = at::kFloat8_e4m3fn; -// reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); -// } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { -// using dtype = at::kFloat8_e5m2; -// reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, + max_seq_len); + // } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { + // using dtype = at::kFloat8_e4m3fn; + // reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); + // } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { + // using dtype = at::kFloat8_e5m2; + // reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); } } template -__global__ void reshape_o_kernel( - scalar_t* output, - scalar_t* output_buffer, - int* step_lens, - int h_o, int d_o, - int b, int max_seq_len, bool is_output_right_aligned) { - // output: bshd; output_buffer: thd; +__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *step_lens, int h_o, + int d_o, int b, int max_seq_len, bool is_output_right_aligned) { + // output: bshd; output_buffer: thd; // step_lens: [b + 1] for (int 
batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int num_elts = step_lens[batch_idx] * h_o * d_o; @@ -121,12 +111,12 @@ __global__ void reshape_o_kernel( output_offset = ((batch_idx + 1) * max_seq_len - step_lens[batch_idx]) * h_o * d_o; } int output_buffer_offset = 0; - for (int t = 0; t < batch_idx; t ++) { + for (int t = 0; t < batch_idx; t++) { output_buffer_offset += step_lens[t]; } output_buffer_offset = output_buffer_offset * h_o * d_o; - scalar_t* output_token = output + output_offset; - scalar_t* output_buffer_token = output_buffer + output_buffer_offset; + scalar_t *output_token = output + output_offset; + scalar_t *output_buffer_token = output_buffer + output_buffer_offset; for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { *(output_buffer_token + i) = *(output_token + i); } @@ -134,62 +124,56 @@ __global__ void reshape_o_kernel( } template -void reshape_o_launcher( - torch::Tensor output, torch::Tensor output_buffer, - torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { +void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(output.data_ptr()), - reinterpret_cast(output_buffer.data_ptr()), - step_lens.data_ptr(), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_buffer.data_ptr()), step_lens.data_ptr(), h_o, d_o, b, max_seq_len, is_output_right_aligned); } -void reshape_o( - torch::Tensor output, torch::Tensor output_buffer, - torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { - NVTE_CHECK( - output.scalar_type() == output_buffer.scalar_type(), - "output and output_buffer must be of the same data type."); +void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, int h_o, + int 
d_o, int b, int max_seq_len, bool is_output_right_aligned) { + NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(), + "output and output_buffer must be of the same data type."); if (output.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); } else if (output.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); } else if (output.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); -// } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { -// using dtype = at::kFloat8_e4m3fn; -// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); -// } else if (output.scalar_type() == at::ScalarType::Float8_e5m2) { -// using dtype = at::kFloat8_e5m2; -// reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); + reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); + // } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { + // using dtype = at::kFloat8_e4m3fn; + // reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); + // } else if (output.scalar_type() == at::ScalarType::Float8_e5m2) { + // using dtype = at::kFloat8_e5m2; + // reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); } else { NVTE_ERROR("Unsupported dtype 
for KV cache.\n"); } } template -__global__ void reindex_kv_cache_kernel( - scalar_t* k_cache, scalar_t* v_cache, - int* batch_indices, - int* step_lens, - int* seq_lens, - int h_kv, int d_k, int d_v, int b, - int max_seq_len) { +__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices, + int *step_lens, int *seq_lens, int h_kv, int d_k, int d_v, + int b, int max_seq_len) { // k_cache, v_cache: bshd // batch_indices, step_lens, seq_lens: [b + 1] int actual_b = b; - for (int i = 0; i < b-1; i++) { - if (batch_indices[i+1] < batch_indices[i]) { - actual_b = i+1; + for (int i = 0; i < b - 1; i++) { + if (batch_indices[i + 1] < batch_indices[i]) { + actual_b = i + 1; } } - for (int batch_idx = 0; batch_idx < actual_b; batch_idx ++) { - for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; token_idx += gridDim.x) { + for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { + for (int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; + token_idx += gridDim.x) { int num_elts_k = h_kv * d_k; int num_elts_v = h_kv * d_v; int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; @@ -205,84 +189,83 @@ __global__ void reindex_kv_cache_kernel( } } if (blockIdx.x == 0) { - for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx ++) { + for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { batch_indices[batch_idx] = batch_idx; } } } template -__global__ void copy_to_kv_cache_kernel( - scalar_t* new_k, scalar_t* new_v, - scalar_t* k_cache, scalar_t* v_cache, - int* page_table, - int* step_lens, - int* seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq) { +__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache, + scalar_t *v_cache, int *page_table, int *step_lens, + int *seq_lens, NVTE_QKV_Format qkv_format, 
int h_kv, + int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq) { int page_size = max_seq_len / max_pages_per_seq; if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int* page_list = page_table + batch_idx * max_pages_per_seq; + int *page_list = page_table + batch_idx * max_pages_per_seq; int new_token_offset = batch_idx * max_ctx_len; for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; - int token_idx = page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; - for (int j = 0; j < h_kv * d_k; j ++) { - *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (new_token_offset + i) * h_kv * d_k +j); - } - for (int j = 0; j < h_kv * d_v; j ++) { - *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (new_token_offset + i) * h_kv * d_v +j); - } + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i) / page_size]; + int token_idx = + page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = + *(new_k + (new_token_offset + i) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = + *(new_v + (new_token_offset + i) * h_kv * d_v + j); + } } } } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int* page_list = page_table + batch_idx * max_pages_per_seq; + int *page_list = page_table + batch_idx * max_pages_per_seq; for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; - int token_idx = page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; - for (int j = 
0; j < h_kv * d_k; j ++) { - *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k +j); - } - for (int j = 0; j < h_kv * d_v; j ++) { - *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v +j); - } + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i) / page_size]; + int token_idx = + page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j); + } } } } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int* page_list = page_table + batch_idx * max_pages_per_seq; + int *page_list = page_table + batch_idx * max_pages_per_seq; int new_token_offset = 0; - for (int t = 0; t < batch_idx; t ++) { - new_token_offset += step_lens[t]; + for (int t = 0; t < batch_idx; t++) { + new_token_offset += step_lens[t]; } for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i)/page_size]; - int token_idx = page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i)%page_size; - for (int j = 0; j < h_kv * d_k; j ++) { - *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (new_token_offset + i) * h_kv * d_k +j); - } - for (int j = 0; j < h_kv * d_v; j ++) { - *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (new_token_offset + i) * h_kv * d_v +j); - } + int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i) / page_size]; + int token_idx = + page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = + 
*(new_k + (new_token_offset + i) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = + *(new_v + (new_token_offset + i) * h_kv * d_v + j); + } } } } } template -void copy_to_kv_cache_launcher( - torch::Tensor new_k, torch::Tensor new_v, - torch::Tensor k_cache, torch::Tensor v_cache, - torch::Tensor page_table, - torch::Tensor step_lens, - torch::Tensor seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged) { +void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, + torch::Tensor step_lens, torch::Tensor seq_lens, + NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged) { // 1. new_k, new_v: qkv_format; k_cache, v_cache: bshd // 2. step_lens, seq_lens (step lens included): [b + 1] // 3. non-paged cache can be considered a special case of paged cache, @@ -291,66 +274,59 @@ void copy_to_kv_cache_launcher( // i.e. page_table = [0, 3, 1, 2] will be rearranged to [0, 1, 1, 2] // 5. assumes k_cache and v_cache have the same page_table // 6. 
for THD, assumes no padding between sequences in new_k and new_v - if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && - k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) { + if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && + v_cache.data_ptr() != nullptr) { if (is_non_paged) { reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - page_table.data_ptr(), - step_lens.data_ptr(), - seq_lens.data_ptr(), - h_kv, d_k, d_v, b, max_seq_len); + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), + step_lens.data_ptr(), seq_lens.data_ptr(), h_kv, d_k, d_v, b, max_seq_len); } copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_k.data_ptr()), - reinterpret_cast(new_v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - page_table.data_ptr(), - step_lens.data_ptr(), - seq_lens.data_ptr(), - qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq); + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), + step_lens.data_ptr(), seq_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, b, + max_ctx_len, max_seq_len, max_pages_per_seq); } } // copy new K/V tokens to KV cache -void copy_to_kv_cache( - torch::Tensor new_k, torch::Tensor new_v, - torch::Tensor k_cache, torch::Tensor v_cache, - torch::Tensor page_table, - torch::Tensor step_lens, - torch::Tensor seq_lens, - NVTE_QKV_Format qkv_format, - int h_kv, int d_k, int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged) { - NVTE_CHECK( - k_cache.scalar_type() == v_cache.scalar_type() && - new_k.scalar_type() == new_v.scalar_type() && - 
new_k.scalar_type() == k_cache.scalar_type(), - "new_k, new_v, k_cache and v_cache must be of the same data type."); - NVTE_CHECK( - qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_THD, - "qkv_format must be {BSHD, SBHD, THD}."); +void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor step_lens, + torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, + int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged) { + NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && + new_k.scalar_type() == new_v.scalar_type() && + new_k.scalar_type() == k_cache.scalar_type(), + "new_k, new_v, k_cache and v_cache must be of the same data type."); + NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_THD, + "qkv_format must be {BSHD, SBHD, THD}."); if (k_cache.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, + seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, + seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); } else if 
(k_cache.scalar_type() == at::ScalarType::Float) { using dtype = float; - copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); -// } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) { -// using dtype = at::kFloat8_e4m3fn; -// copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); -// } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) { -// using dtype = at::kFloat8_e5m2; -// copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, + seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + // } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) { + // using dtype = at::kFloat8_e4m3fn; + // copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + // } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) { + // using dtype = at::kFloat8_e5m2; + // copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); } diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 63053d7ec9..b6011b1f88 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -257,7 +257,11 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if 
callables[0].training: grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if isinstance(i, torch.Tensor) and i.requires_grad), + inputs=tuple( + i + for i in static_input_surface + if isinstance(i, torch.Tensor) and i.requires_grad + ), grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), only_inputs=True, allow_unused=allow_unused_input, @@ -372,7 +376,11 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument with torch.cuda.graph(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if isinstance(i, torch.Tensor) and i.requires_grad), + inputs=tuple( + i + for i in static_input_surface + if isinstance(i, torch.Tensor) and i.requires_grad + ), grad_outputs=tuple(o for o in static_grad_outputs if o is not None), only_inputs=True, allow_unused=allow_unused_input, @@ -425,7 +433,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Copy values from new tensors into static tensors for i in range(len_user_args): - if isinstance(static_input_surface[i], torch.Tensor) and static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + if ( + isinstance(static_input_surface[i], torch.Tensor) + and static_input_surface[i].data_ptr() != inputs[i].data_ptr() + ): static_input_surface[i].copy_(inputs[i]) # Replay forward graph diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 0d8b660369..74a3493e19 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -15,6 +15,7 @@ from transformer_engine.pytorch.kv_cache_manager_paged import PagedKVCacheManager from transformer_engine.pytorch.kv_cache_manager_non_paged import NonPagedKVCacheManager + class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are 
passed to the main model in order @@ -69,7 +70,7 @@ def __init__( self.is_paged = is_paged if not self.is_paged: - cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager + cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager self.cache_manager = cls( max_batch_size=self.max_batch_size, max_seqlen=self.max_seqlen_kv, @@ -91,7 +92,7 @@ def __init__( self.max_seqlen_kv = max_seqlen_kv self.total_num_pages = total_num_pages - cls = cache_manager if cache_manager is not None else PagedKVCacheManager + cls = cache_manager if cache_manager is not None else PagedKVCacheManager self.cache_manager = cls( total_num_pages=self.total_num_pages, page_size=self.page_size, @@ -139,7 +140,7 @@ def reset(self): """ self.sequences = collections.OrderedDict() self.cache_manager.reset() - if self.input_qkv_format == 'thd': + if self.input_qkv_format == "thd": for layer_number in self.q_buffer: self.q_buffer[layer_number].fill_(0) @@ -181,7 +182,7 @@ def allocate_memory(self, layer_number: int, qkv_format: str): """ self.cache_manager.allocate_memory(layer_number) - if qkv_format == 'thd': + if qkv_format == "thd": self.q_buffer[layer_number] = torch.zeros( self.max_batch_size, self.max_ctx_len, @@ -218,16 +219,14 @@ def pre_step( seqlens_q = list(step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) - self.cu_seqlens_q.copy_( - torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu") - ) + self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu")) seq_lens = list(self.sequences.values()) - #seq_lens = [self.max_seqlen_kv] * self.batch_size + # seq_lens = [self.max_seqlen_kv] * self.batch_size cu_seqlens_kv = [0] + [sum(seq_lens[:i]) for i in range(1, actual_batch_size + 1)] - cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * (self.max_batch_size - 
actual_batch_size) - self.cu_seqlens_kv.copy_( - torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu") + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + self.max_batch_size - actual_batch_size ) + self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu")) def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): """ @@ -367,32 +366,45 @@ def step( # self.max_batch_size, ctx_len, self.max_ctx_len) k_cache, v_cache, page_table = self.cache_manager.step( - layer_number, k, v, self.cu_seqlens_q, self.cu_seqlens_kv, qkv_format, + layer_number, + k, + v, + self.cu_seqlens_q, + self.cu_seqlens_kv, + qkv_format, ) - return q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, self.max_seqlen_q, self.max_seqlen_kv, self.output_qkv_format + return ( + q_buffer, + k_cache, + v_cache, + page_table, + self.cu_seqlens_q, + self.cu_seqlens_kv, + self.max_seqlen_q, + self.max_seqlen_kv, + self.output_qkv_format, + ) def post_step( self, layer_number: int, output: torch.Tensor, - ): + ): """ Process the attention output in order to return it in the original qkv_format. 
""" if self.input_qkv_format == "bshd": - output = output[:self.batch_size, :self.max_seqlen_q].contiguous() + output = output[: self.batch_size, : self.max_seqlen_q].contiguous() if self.input_qkv_format == "sbhd": - output = output[:self.batch_size, :self.max_seqlen_q].transpose(0, 1).contiguous() + output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() if self.input_qkv_format == "thd": - print('oooo ', output.shape) - #print(output[:2, :4]) - #output_buffer = self.q_orig[layer_number] - #step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - #tex.reshape_o(output, output_buffer, step_lens, + print("oooo ", output.shape) + # print(output[:2, :4]) + # output_buffer = self.q_orig[layer_number] + # step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + # tex.reshape_o(output, output_buffer, step_lens, # self.num_heads_q, self.head_dim_q, self.batch_size, self.max_ctx_len, self.is_output_right_aligned) - #output = output_buffer.view(output_buffer.shape[0], -1) + # output = output_buffer.view(output_buffer.shape[0], -1) return output - - diff --git a/transformer_engine/pytorch/kv_cache_manager.py b/transformer_engine/pytorch/kv_cache_manager.py index 072919821f..4e9bb8353e 100644 --- a/transformer_engine/pytorch/kv_cache_manager.py +++ b/transformer_engine/pytorch/kv_cache_manager.py @@ -8,10 +8,12 @@ import torch + class KVCacheManager: """ KV cache manager. The base class for custom cache managers. 
""" + def __init__(self, *args, **kwargs): """Initialize the cache manager.""" self.cache = {} diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index c4f8d59a65..598595b99d 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -10,6 +10,7 @@ from transformer_engine.pytorch.kv_cache_manager import KVCacheManager from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat + class NonPagedKVCacheManager(KVCacheManager): """ The non-paged KV cache manager. @@ -72,11 +73,15 @@ def pre_step( finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_(torch.Tensor(( - unfinished_indices - + finished_indices - + list(range(prev_batch_size, self.max_batch_size)) - )).to(dtype=torch.int32, device="cpu")) + self.batch_indices.copy_( + torch.Tensor( + ( + unfinished_indices + + finished_indices + + list(range(prev_batch_size, self.max_batch_size)) + ) + ).to(dtype=torch.int32, device="cpu") + ) # Advance unfinished sequences for i in unfinished_seqs: @@ -98,7 +103,7 @@ def step( layer_number, k: torch.Tensor, v: torch.Tensor, - #step_dict: OrderedDict, + # step_dict: OrderedDict, cu_seqlens_q, cu_seqlens_kv, qkv_format: str, @@ -131,19 +136,31 @@ def step( step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] batch_size = self.max_batch_size - ctx_len=1 + ctx_len = 1 if qkv_format == "bshd": batch_size = k.shape[0] - ctx_len=k.shape[1] + ctx_len = k.shape[1] if qkv_format == "sbhd": batch_size = k.shape[1] - ctx_len=k.shape[0] + ctx_len = k.shape[0] tex.copy_to_kv_cache( - k, v, k_cache, v_cache, - self.batch_indices, step_lens, seq_lens, + k, + v, + k_cache, + v_cache, 
+ self.batch_indices, + step_lens, + seq_lens, QKVFormat[qkv_format], - self.num_heads, self.head_dim_k, self.head_dim_v, - batch_size, ctx_len, self.max_seqlen, 1, True) + self.num_heads, + self.head_dim_k, + self.head_dim_v, + batch_size, + ctx_len, + self.max_seqlen, + 1, + True, + ) k_cache = k_cache[:batch_size] v_cache = v_cache[:batch_size] return k_cache, v_cache, None diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index 5610f4f405..d67740b613 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -47,7 +47,7 @@ def __init__( max_batch_size: int, max_seqlen: int, head_dim_v: Optional[int] = None, - #is_cuda_graph: bool = False, + # is_cuda_graph: bool = False, ): """Initialize the KV cache""" self.total_num_pages = total_num_pages @@ -59,7 +59,7 @@ def __init__( self.max_seqlen = max_seqlen self.max_pages_per_seq = max_seqlen // self.page_size self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - #self.is_cuda_graph = is_cuda_graph + # self.is_cuda_graph = is_cuda_graph # sequences contained in the kv cache, {seq_id: seq_len} self.sequences = OrderedDict() @@ -192,7 +192,7 @@ def deallocate_sequence(self, seq: int): def pre_step( self, step_dict: Dict[List, List], - ): + ): batch_size = len(step_dict) step_lens = list(step_dict.values()) cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] @@ -224,7 +224,7 @@ def step( layer_number: int, k: torch.Tensor, v: torch.Tensor, - #step_dict: OrderedDict, + # step_dict: OrderedDict, cu_seqlens_q: torch.Tensor, cu_seqlens_kv: torch.Tensor, qkv_format: str, @@ -257,19 +257,31 @@ def step( step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] batch_size = self.max_batch_size - ctx_len=1 + ctx_len = 1 if qkv_format == "bshd": batch_size = k.shape[0] - ctx_len=k.shape[1] + ctx_len = 
k.shape[1] if qkv_format == "sbhd": batch_size = k.shape[1] - ctx_len=k.shape[0] + ctx_len = k.shape[0] tex.copy_to_kv_cache( - k, v, k_cache, v_cache, - self.page_table, step_lens, seq_lens, + k, + v, + k_cache, + v_cache, + self.page_table, + step_lens, + seq_lens, QKVFormat[qkv_format], - self.num_heads, self.head_dim_k, self.head_dim_v, - batch_size, ctx_len, self.max_seqlen, self.max_pages_per_seq, False) + self.num_heads, + self.head_dim_k, + self.head_dim_v, + batch_size, + ctx_len, + self.max_seqlen, + self.max_pages_per_seq, + False, + ) page_table = self.page_table[:batch_size] return k_cache, v_cache, page_table From b4fbc2b3c3d466c12822dd26f947bca9e4fe9234 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 21 Feb 2025 18:13:22 -0800 Subject: [PATCH 106/239] [PyTorch] Use same API in optimizer `zero_grad` as PyTorch optimizers (#1466) Use same API in optimizer zero_grad as PyT optimizers Signed-off-by: Tim Moon --- .../pytorch/optimizers/fused_adam.py | 67 +++++++++++++------ .../pytorch/optimizers/fused_sgd.py | 59 ++++++++++++---- 2 files changed, 94 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index d972fd96ab..070f46e937 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -3,8 +3,12 @@ # See LICENSE for license information. """Fused Adam optimizer.""" +from __future__ import annotations +from collections.abc import Iterable from copy import deepcopy from itertools import chain +from typing import Optional +import warnings import torch import transformer_engine_torch as tex @@ -52,8 +56,6 @@ class FusedAdam(torch.optim.Optimizer): params (iterable): iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. 
(default: 1e-3) - bias_correction (bool, optional): apply correction factor to - moment estimates. (default: True) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve @@ -62,10 +64,10 @@ class FusedAdam(torch.optim.Optimizer): amsgrad (boolean, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) NOT SUPPORTED in FusedAdam! + bias_correction (bool, optional): apply correction factor to + moment estimates. (default: True) adam_w_mode (boolean, optional): Apply L2 regularization or weight decay True for decoupled weight decay(also known as AdamW) (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) capturable (bool, optional): whether to use the version of the optimizer that can be used with CUDA Graphs. 
(default: False) master_weights (bool, optional): whether to maintain FP32 master weights @@ -106,15 +108,15 @@ class FusedAdam(torch.optim.Optimizer): def __init__( self, - params, - lr=1e-3, + params: Iterable[torch.nn.Parameter | dict], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False, + *, bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, adam_w_mode=True, - weight_decay=0.0, - amsgrad=False, - set_grad_none=True, capturable=False, master_weights=False, master_weight_dtype=torch.float32, @@ -122,6 +124,7 @@ def __init__( exp_avg_sq_dtype=torch.float32, use_decoupled_grad=False, store_param_remainders=False, + set_grad_none: Optional[bool] = None, # deprecated ): if amsgrad: @@ -160,7 +163,6 @@ def __init__( } super().__init__(params, defaults) self.adam_w_mode = 1 if adam_w_mode else 0 - self.set_grad_none = set_grad_none self.capturable = capturable self.master_weights = master_weights @@ -204,19 +206,46 @@ def __init__( store_param_remainders and master_weights and master_weight_dtype == torch.float32 ) - def zero_grad(self): - # pylint: disable=missing-function-docstring - if not self.use_decoupled_grad and not self.set_grad_none: - super().zero_grad() + # Deprecated options + self.set_grad_none = set_grad_none + if self.set_grad_none is not None: + warnings.warn( + "set_grad_none kwarg in FusedAdam constructor is deprecated. " + "Use set_to_none kwarg in zero_grad instead.", + DeprecationWarning, + ) + + def zero_grad(self, set_to_none: Optional[bool] = None) -> None: + """Reset parameter gradients. + + Arguments: + set_to_none (bool, optional): whether to set grads to `None` + instead of zeroing out buffers. 
(default: True) + + """ + + # Handle deprecated set_grad_none option + if self.set_grad_none is not None: + if set_to_none is not None and set_to_none != self.set_grad_none: + raise ValueError( + f"Called zero_grad with set_to_none={set_to_none}, " + f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + ) + set_to_none = self.set_grad_none + if set_to_none is None: + set_to_none = True + + if not self.use_decoupled_grad and not set_to_none: + super().zero_grad(set_to_none=set_to_none) return for group in self.param_groups: for p in group["params"]: - if self.use_decoupled_grad and self.set_grad_none: + if self.use_decoupled_grad and set_to_none: p.decoupled_grad = None - elif self.use_decoupled_grad and not self.set_grad_none: + elif self.use_decoupled_grad and not set_to_none: p.decoupled_grad.zero_() - elif not self.use_decoupled_grad and self.set_grad_none: + elif not self.use_decoupled_grad and set_to_none: p.grad = None def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index 53fa59821c..8a76ec5901 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -3,6 +3,11 @@ # See LICENSE for license information. 
"""Fused SGD optimizer.""" +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional +import warnings + import torch from torch.optim.optimizer import Optimizer, required @@ -37,8 +42,8 @@ class FusedSGD(Optimizer): parameter groups lr (float): learning rate momentum (float, optional): momentum factor (default: 0) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) dampening (float, optional): dampening for momentum (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) nesterov (bool, optional): enables Nesterov momentum (default: False) Example: @@ -74,15 +79,16 @@ class FusedSGD(Optimizer): def __init__( self, - params, - lr=required, - momentum=0, - dampening=0, - weight_decay=0, - nesterov=False, + params: Iterable[torch.nn.Parameter | dict], + lr: float | Any = required, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + *, wd_after_momentum=False, materialize_master_grads=True, - set_grad_none=False, + set_grad_none: Optional[bool] = None, # deprecated ): if lr is not required and lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") @@ -98,7 +104,7 @@ def __init__( "weight_decay": weight_decay, "nesterov": nesterov, } - if nesterov and (momentum <= 0 or dampening != 0): + if nesterov and (momentum <= 0.0 or dampening != 0.0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) @@ -106,7 +112,6 @@ def __init__( self.materialize_master_grads = materialize_master_grads self.most_recent_scale = 1.0 self.scale_set_by_backward = False - self.set_grad_none = set_grad_none # Skip buffer self._dummy_overflow_buf = torch.tensor( @@ -114,14 +119,42 @@ def __init__( ) self.multi_tensor_sgd = tex.multi_tensor_sgd + # Deprecated options + self.set_grad_none = set_grad_none + if self.set_grad_none is not None: + warnings.warn( + "set_grad_none kwarg in 
FusedAdam constructor is deprecated. " + "Use set_to_none kwarg in zero_grad instead.", + DeprecationWarning, + ) + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("nesterov", False) - def zero_grad(self): - # pylint: disable=missing-function-docstring - if self.set_grad_none: + def zero_grad(self, set_to_none: Optional[bool] = None) -> None: + """Reset parameter gradients. + + Arguments: + set_to_none (bool, optional): whether to set grads to `None` + instead of zeroing out buffers. (default: True) + + """ + + # Handle deprecated set_grad_none option + if self.set_grad_none is not None: + if set_to_none is not None and set_to_none != self.set_grad_none: + raise ValueError( + f"Called zero_grad with set_to_none={set_to_none}, " + f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + ) + set_to_none = self.set_grad_none + if set_to_none is None: + set_to_none = True + + # Reset grads + if set_to_none: for group in self.param_groups: for p in group["params"]: p.grad = None From 7f2dcf91b9c411d0bea6e56ac83edf0917d389df Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Fri, 21 Feb 2025 21:53:55 -0800 Subject: [PATCH 107/239] [Pytorch] Decoupling framework extensions from common module (#1498) * Remove dependency on transformer_engine::Tensor in attention.cu Signed-off-by: Kshitij Janardan Lakhani * Templatize thd_partition_indices_kernel and thd_read_half_tensor_kernel kernels ONLY for invoking recompilation and not directly using the pre-compiled symbols in libtransformer.so Signed-off-by: Kshitij Janardan Lakhani * Modify attention.cu for thd templatized kernels. 
Remove dependency on common.h Signed-off-by: Kshitij Janardan Lakhani * Move thd structs from libtransformer.so to framework extensions include header Code cleanup Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Consolidate and move thd_utils from common to framework extensions Signed-off-by: Kshitij Janardan Lakhani * Remove template decorators around thd_partition_indices_kernel and thd_read_half_tensor_kernel Signed-off-by: Kshitij Janardan Lakhani Code clean up Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Janardan Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/CMakeLists.txt | 1 - .../common/fused_attn/thd_utils.cu | 76 --------- .../pytorch/csrc/extensions/attention.cu | 68 +++++--- .../csrc/thd_utils.cuh} | 160 ++++++++++++------ 4 files changed, 147 insertions(+), 158 deletions(-) delete mode 100644 transformer_engine/common/fused_attn/thd_utils.cu rename transformer_engine/{common/fused_attn/thd_utils.h => pytorch/csrc/thd_utils.cuh} (80%) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index ed59153954..c77d230ce5 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -65,7 +65,6 @@ list(APPEND transformer_engine_SOURCES activation/swiglu.cu fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp - fused_attn/thd_utils.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu normalization/common.cpp diff --git a/transformer_engine/common/fused_attn/thd_utils.cu b/transformer_engine/common/fused_attn/thd_utils.cu deleted file mode 100644 index 17c732c530..0000000000 --- a/transformer_engine/common/fused_attn/thd_utils.cu +++ /dev/null @@ -1,76 +0,0 @@ 
-/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../cudnn_utils.h" -#include "thd_utils.h" - -namespace transformer_engine { -namespace fused_attn { - -__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, - int total_tokens, int world_size, int rank) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - int seqlen = cu_seqlens[i]; - // Currently we assume that each sequence length is divisible by (world_size*2) since we have - // to distribute each sequence evenly to different GPUs. - assert(seqlen % (world_size * 2) == 0); - cu_seqlens_s[i] = seqlen / world_size; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - - for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len / 2 ? 
rank : (world_size - 1) * 2 - rank; - index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; - output[token_id] = index; - } -} - -__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, - int hidden_size_in_bytes, int half_idx, - int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; - int laneid = threadIdx.x % 32; - int num_warps = (blockDim.x * gridDim.x) / 32; - int num_total_tokens = cu_seqlens_s[batch]; - int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; - half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); - tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); - - for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { - int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); - - size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; - float4 *cur_half_token = - reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); - - offset_in_bytes = - (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; - float4 *cur_token = - reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); - - for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { - cur_half_token[idx] = cur_token[idx]; - } - } -} - -} // namespace fused_attn -} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index f2d1ecf3b9..1e2af2b2d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -3,12 +3,8 @@ * * See LICENSE for license information. 
************************************************************************/ - -#include "common/common.h" -#include "common/fused_attn/thd_utils.h" #include "extensions.h" - -using namespace transformer_engine::fused_attn; +#include "thd_utils.cuh" constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -208,28 +204,40 @@ std::vector fused_attn_fwd( std::vector output_tensors; output_tensors.push_back(o_python); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors at::Tensor output_tensor; if (nvte_aux_tensor_pack.size >= 2) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); + output_tensor = allocateSpace( + nvte_shape_to_vector(temp_shape), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); } else if (i == nvte_aux_tensor_pack.size - 2) { output_tensor = rng_state; } else if (i == nvte_aux_tensor_pack.size - 1) { output_tensor = Bias.value(); } } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; + NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); + output_tensor = + (i < nvte_aux_tensor_pack.size - 1) + ? 
allocateSpace( + nvte_shape_to_vector(temp_shape), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) + : rng_state; } } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); + output_tensor = allocateSpace( + nvte_shape_to_vector(temp_shape), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); } output_tensors.push_back(py::cast(output_tensor)); - tensor->data.dptr = output_tensor.data_ptr(); + NVTEBasicTensor temp_data = {output_tensor.data_ptr(), + nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), + nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; + nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } // execute the kernel @@ -425,11 +433,14 @@ std::vector fused_attn_bwd( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + auto temp_vec = std::vector(tmp.begin(), tmp.end()); + const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()}; + NVTEBasicTensor temp_data = { + Aux_CTX_Tensors[i].data_ptr(), + static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), + temp_shape}; + nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } // create dBias the same shape as Bias @@ -662,8 +673,8 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s grid_y *= tensor.size(i); } dim3 grid = {grid_x, grid_y}; - thd_read_half_tensor_kernel<<>>( + 
transformer_engine::fused_attn::thd_read_half_tensor_kernel<<< + grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr(), batch, hidden_size_in_bytes, half_idx, tensor.size(seq_dim)); @@ -713,13 +724,14 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; + if (lse_packed) { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, lse_seqlen, second_half_lse_seqlen); } else { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, lse_seqlen, second_half_lse_seqlen); @@ -764,13 +776,14 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; + if (lse_packed) { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, lse_seqlen, second_half_lse_seqlen); } else { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, lse_seqlen, second_half_lse_seqlen); @@ -829,13 +842,13 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ dim3 grid = {grid_x, (unsigned int)num_heads}; if (lse_packed) { - thd_out_correction_kernel + transformer_engine::fused_attn::thd_out_correction_kernel <<>>( out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, dim_per_head, lse_seqlen, lse_per_step_seqlen); } else { - thd_out_correction_kernel 
+ transformer_engine::fused_attn::thd_out_correction_kernel <<>>( out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, @@ -925,7 +938,8 @@ static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_p } dim3 grid = {grid_x, grid_y}; - thd_grad_correction_kernel + transformer_engine::fused_attn::thd_grad_correction_kernel <<>>( grad.data_ptr(), grad_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, hidden_size, total_tokens); @@ -992,8 +1006,8 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t constexpr unsigned int block = 256; unsigned int grid = (output.size(0) + block - 1) / block; - thd_partition_indices_kernel<<>>( + transformer_engine::fused_attn::thd_partition_indices_kernel<<< + grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( output.data_ptr(), cu_seqlens.data_ptr(), batch, total_tokens, world_size, rank); return output; diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/pytorch/csrc/thd_utils.cuh similarity index 80% rename from transformer_engine/common/fused_attn/thd_utils.h rename to transformer_engine/pytorch/csrc/thd_utils.cuh index ec265e4366..1f1f0cfdfd 100644 --- a/transformer_engine/common/fused_attn/thd_utils.h +++ b/transformer_engine/pytorch/csrc/thd_utils.cuh @@ -3,13 +3,59 @@ * * See LICENSE for license information. 
************************************************************************/ +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_ -#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ -#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ - +#include #include #include +struct LseCorrectionFunctor { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +struct ReadLseFunctor { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +struct EmptyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + +#pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + namespace transformer_engine { namespace fused_attn { @@ -33,39 +79,74 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) { /*************************************************************************************************** * Support THD format for Context Parallel: Generate partitioned indices for input tokens 
**************************************************************************************************/ - __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, - int total_tokens, int world_size, int rank); + int total_tokens, int world_size, int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. + assert(seqlen % (world_size * 2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? 
rank : (world_size - 1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ - __global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, int hidden_size_in_bytes, int half_idx, - int dim_size_of_token); + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); -/*************************************************************************************************** - * Support THD format for Context Parallel: softmax_lse related operations - **************************************************************************************************/ + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); -struct LseCorrectionFunctor { - __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, - size_t half_idx) { - double val = lse[idx]; - float val_per_step = half_lse[half_idx]; - double max_scale = max(val, val_per_step); - double min_scale = min(val, val_per_step); - lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); - } -}; + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); -struct ReadLseFunctor { - __forceinline__ __device__ static void run(float *lse, float 
*half_lse, size_t idx, - size_t half_idx) { - half_lse[half_idx] = lse[idx]; + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); + + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } } -}; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ template __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, @@ -163,34 +244,6 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float * Support THD format for Context Parallel: Gradients correction in backward **************************************************************************************************/ -struct EmptyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} -}; - -struct CopyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { - reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; - } -}; - -template -struct AddFunctor { - __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { - float4 d_ = reinterpret_cast(token)[idx]; - dtype *p_ = reinterpret_cast(&d_); - - float4 d = reinterpret_cast(token_per_step)[idx]; - dtype *p = reinterpret_cast(&d); - -#pragma unroll - for 
(int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { - p_[i] += p[i]; - } - - reinterpret_cast(token)[idx] = d_; - } -}; - template __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, int batch, int hidden_size, int dim_size_of_token) { @@ -246,5 +299,4 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in } // namespace fused_attn } // namespace transformer_engine - #endif From 9ec36495015f4062d8e1f91532ab3f57705aa023 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 22 Feb 2025 17:08:17 -0800 Subject: [PATCH 108/239] WIP: some cleanup Signed-off-by: Charlene Yang --- qa/L0_pytorch_unittest/test.sh | 3 +- transformer_engine/common/CMakeLists.txt | 3 +- .../include/transformer_engine/fused_attn.h | 20 +- transformer_engine/pytorch/attention.py | 4 +- .../pytorch/cpp_extensions/fused_attn.py | 26 - transformer_engine/pytorch/csrc/common.h | 2 + transformer_engine/pytorch/csrc/extensions.h | 22 +- .../pytorch/csrc/extensions/attention.cu | 485 ++++++------------ .../pytorch/csrc/extensions/pybind.cpp | 25 +- transformer_engine/pytorch/csrc/kv_cache.cuh | 158 ++++++ transformer_engine/pytorch/graph.py | 3 +- transformer_engine/pytorch/inference.py | 230 ++++----- .../pytorch/kv_cache_manager.py | 24 +- .../pytorch/kv_cache_manager_non_paged.py | 83 +-- .../pytorch/kv_cache_manager_paged.py | 97 ++-- 15 files changed, 570 insertions(+), 615 deletions(-) create mode 100644 transformer_engine/pytorch/csrc/kv_cache.cuh diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a08d6c9f90..9ca052e761 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -13,8 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 
pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py @@ -23,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 289ab19dd7..c77d230ce5 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -98,8 +98,7 @@ target_include_directories(transformer_engine PUBLIC # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas - CUDA::cudart - CUDNN::cudnn_all) + CUDA::cudart) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 2e6be8d178..62ad226962 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -157,7 +157,7 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); 
/*! \brief Get Q format for a given QKV layout. * - * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd. * * \return q format, e.g. sbhd. */ @@ -165,9 +165,9 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get KV format for a given QKV layout. * - * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd. * - * \return kv format, e.g. sbhd. + * \return kv format, e.g. bshd. */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); @@ -367,14 +367,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, +void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream); /*! 
\brief Compute the backward of the dot product attention with packed KV input. * diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7c17290039..fe20f91b12 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Attention.""" +"""Attention""" import collections from contextlib import nullcontext from importlib.metadata import version as get_pkg_version @@ -363,7 +363,7 @@ def __eq__(self, other): } -__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] +__all__ = ["DotProductAttention", "MultiheadAttention"] def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 8bf9160480..b9810bf861 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -276,32 +276,6 @@ def fused_attn_fwd( # execute kernel - # print(max_seqlen_q, - # max_seqlen_kv, - # is_training, - # attn_scale, - # dropout, - # fast_zero_fill, - # QKVLayout[qkv_layout], - # AttnBiasType[attn_bias_type], - # AttnMaskType[attn_mask_type], - # window_size, - # cu_seqlens_q, - # cu_seqlens_kv, - # q.shape, - # k.shape, - # v.shape, - # fake_dtype, - # cu_seqlens_q_padded, - # cu_seqlens_kv_padded, - # page_table_k, - # page_table_v, - # s_quantizer, - # o_quantizer, - # attn_bias, - # rng_gen, - # rng_elts_per_thread, - # ) output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 40245cf2d9..9aa589de32 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include #include #include diff --git 
a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7a4340bd39..6689c89d73 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -34,18 +34,6 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T /*************************************************************************************************** * Attention **************************************************************************************************/ -void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens, - NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, - int max_seq_len); - -void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, int h_o, - int d_o, int b, int max_seq_len, bool is_output_right_aligned); - -void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, - torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor step_lens, - torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, - int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged); NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, @@ -82,6 +70,16 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); +void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len); +void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o, + int d_o, int b, int max_seq_len, bool is_output_right_aligned); +void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, + torch::Tensor 
cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k, + int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged); + /*************************************************************************************************** * GEMM **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 96ca3c9420..1caf6b15e4 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -5,329 +5,11 @@ ************************************************************************/ #include "extensions.h" #include "thd_utils.cuh" +#include "kv_cache.cuh" constexpr int block_size = 512; constexpr int ctas_per_sm = 4; -template -__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *step_lens, - NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, - int max_ctx_len, int max_seq_len) { - // new_q: qkv_format; q_buffer: bshd - // step_lens: [b + 1] - if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_q * d_q; - int new_token_offset = batch_idx * max_ctx_len * h_q * d_q; - int cache_offset = batch_idx * max_seq_len * h_q * d_q; - scalar_t *new_q_token = new_q + new_token_offset; - scalar_t *q_buffer_token = q_buffer + cache_offset; - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(q_buffer_token + i) = *(new_q_token + i); - } - } - } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int cache_offset = batch_idx * max_seq_len; - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - for (int j = 0; j < h_q * d_q; j++) { - *(q_buffer + (cache_offset + i) * h_q * d_q + j) 
= - *(new_q + (i * b + batch_idx) * h_q * d_q + j); - } - } - } - } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_q * d_q; - int new_token_offset = 0; - for (int t = 0; t < batch_idx; t++) { - new_token_offset += step_lens[t]; - } - new_token_offset = new_token_offset * h_q * d_q; - int cache_offset = batch_idx * max_seq_len * h_q * d_q; - scalar_t *new_q_token = new_q + new_token_offset; - scalar_t *q_buffer_token = q_buffer + cache_offset; - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(q_buffer_token + i) = *(new_q_token + i); - } - } - } -} - -template -void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens, - NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, - int max_seq_len) { - reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_q.data_ptr()), - reinterpret_cast(q_buffer.data_ptr()), step_lens.data_ptr(), - qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); -} - -void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor step_lens, - NVTE_QKV_Format qkv_format, int h_q, int d_q, int b, int max_ctx_len, - int max_seq_len) { - NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), - "new_q and q_buffer must be of the same data type."); - NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_THD, - "qkv_format must be {BSHD, SBHD, THD}."); - if (q_buffer.scalar_type() == at::ScalarType::Half) { - using dtype = at::Half; - reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, - max_seq_len); - } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { - using dtype = at::BFloat16; - reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, - max_seq_len); - } else 
if (q_buffer.scalar_type() == at::ScalarType::Float) { - using dtype = float; - reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, - max_seq_len); - // } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { - // using dtype = at::kFloat8_e4m3fn; - // reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); - // } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { - // using dtype = at::kFloat8_e5m2; - // reshape_q_launcher(new_q, q_buffer, step_lens, qkv_format, h_q, d_q, b, max_ctx_len, max_seq_len); - } else { - NVTE_ERROR("Unsupported dtype for KV cache.\n"); - } -} - -template -__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *step_lens, int h_o, - int d_o, int b, int max_seq_len, bool is_output_right_aligned) { - // output: bshd; output_buffer: thd; - // step_lens: [b + 1] - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = step_lens[batch_idx] * h_o * d_o; - int output_offset = batch_idx * max_seq_len * h_o * d_o; - if (is_output_right_aligned) { - output_offset = ((batch_idx + 1) * max_seq_len - step_lens[batch_idx]) * h_o * d_o; - } - int output_buffer_offset = 0; - for (int t = 0; t < batch_idx; t++) { - output_buffer_offset += step_lens[t]; - } - output_buffer_offset = output_buffer_offset * h_o * d_o; - scalar_t *output_token = output + output_offset; - scalar_t *output_buffer_token = output_buffer + output_buffer_offset; - for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(output_buffer_token + i) = *(output_token + i); - } - } -} - -template -void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { - reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(output.data_ptr()), - 
reinterpret_cast(output_buffer.data_ptr()), step_lens.data_ptr(), - h_o, d_o, b, max_seq_len, is_output_right_aligned); -} - -void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor step_lens, int h_o, - int d_o, int b, int max_seq_len, bool is_output_right_aligned) { - NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(), - "output and output_buffer must be of the same data type."); - if (output.scalar_type() == at::ScalarType::Half) { - using dtype = at::Half; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - } else if (output.scalar_type() == at::ScalarType::BFloat16) { - using dtype = at::BFloat16; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - } else if (output.scalar_type() == at::ScalarType::Float) { - using dtype = float; - reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - // } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { - // using dtype = at::kFloat8_e4m3fn; - // reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); - // } else if (output.scalar_type() == at::ScalarType::Float8_e5m2) { - // using dtype = at::kFloat8_e5m2; - // reshape_o_launcher(output, output_buffer, step_lens, h_o, d_o, b, max_seq_len, is_output_right_aligned); - } else { - NVTE_ERROR("Unsupported dtype for KV cache.\n"); - } -} - -template -__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices, - int *step_lens, int *seq_lens, int h_kv, int d_k, int d_v, - int b, int max_seq_len) { - // k_cache, v_cache: bshd - // batch_indices, step_lens, seq_lens: [b + 1] - int actual_b = b; - for (int i = 0; i < b - 1; i++) { - if (batch_indices[i + 1] < batch_indices[i]) { - actual_b = i + 1; - } - } - for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { - for 
(int token_idx = blockIdx.x; token_idx < seq_lens[batch_idx] - step_lens[batch_idx]; - token_idx += gridDim.x) { - int num_elts_k = h_kv * d_k; - int num_elts_v = h_kv * d_v; - int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; - int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; - int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; - int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; - for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { - *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); - } - for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { - *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); - } - } - } - if (blockIdx.x == 0) { - for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { - batch_indices[batch_idx] = batch_idx; - } - } -} - -template -__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache, - scalar_t *v_cache, int *page_table, int *step_lens, - int *seq_lens, NVTE_QKV_Format qkv_format, int h_kv, - int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, - int max_pages_per_seq) { - int page_size = max_seq_len / max_pages_per_seq; - if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int *page_list = page_table + batch_idx * max_pages_per_seq; - int new_token_offset = batch_idx * max_ctx_len; - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i) / page_size]; - int token_idx = - page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i) % page_size; - for (int j = 0; j < h_kv * d_k; j++) { - *(k_cache + token_idx * h_kv * d_k + j) = - *(new_k + (new_token_offset + i) * h_kv * d_k + j); - } - for (int j = 0; j < h_kv * d_v; 
j++) { - *(v_cache + token_idx * h_kv * d_v + j) = - *(new_v + (new_token_offset + i) * h_kv * d_v + j); - } - } - } - } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int *page_list = page_table + batch_idx * max_pages_per_seq; - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i) / page_size]; - int token_idx = - page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i) % page_size; - for (int j = 0; j < h_kv * d_k; j++) { - *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); - } - for (int j = 0; j < h_kv * d_v; j++) { - *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j); - } - } - } - } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int *page_list = page_table + batch_idx * max_pages_per_seq; - int new_token_offset = 0; - for (int t = 0; t < batch_idx; t++) { - new_token_offset += step_lens[t]; - } - for (int i = threadIdx.x; i < step_lens[batch_idx]; i += blockDim.x) { - int page_idx = page_list[(seq_lens[batch_idx] - step_lens[batch_idx] + i) / page_size]; - int token_idx = - page_idx * page_size + (seq_lens[batch_idx] - step_lens[batch_idx] + i) % page_size; - for (int j = 0; j < h_kv * d_k; j++) { - *(k_cache + token_idx * h_kv * d_k + j) = - *(new_k + (new_token_offset + i) * h_kv * d_k + j); - } - for (int j = 0; j < h_kv * d_v; j++) { - *(v_cache + token_idx * h_kv * d_v + j) = - *(new_v + (new_token_offset + i) * h_kv * d_v + j); - } - } - } - } -} - -template -void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, - torch::Tensor v_cache, torch::Tensor page_table, - torch::Tensor step_lens, torch::Tensor seq_lens, - NVTE_QKV_Format qkv_format, int h_kv, int d_k, 
int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged) { - // 1. new_k, new_v: qkv_format; k_cache, v_cache: bshd - // 2. step_lens, seq_lens (step lens included): [b + 1] - // 3. non-paged cache can be considered a special case of paged cache, - // where page_table = [b, 1] and max_pages_per_seq = 1 - // 4. is_non_paged = True forces re-indexing of the cache based on page_table, - // i.e. page_table = [0, 3, 1, 2] will be rearranged to [0, 1, 1, 2] - // 5. assumes k_cache and v_cache have the same page_table - // 6. for THD, assumes no padding between sequences in new_k and new_v - if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && - v_cache.data_ptr() != nullptr) { - if (is_non_paged) { - reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), - step_lens.data_ptr(), seq_lens.data_ptr(), h_kv, d_k, d_v, b, max_seq_len); - } - copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_k.data_ptr()), - reinterpret_cast(new_v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), - step_lens.data_ptr(), seq_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, b, - max_ctx_len, max_seq_len, max_pages_per_seq); - } -} - -// copy new K/V tokens to KV cache -void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, - torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor step_lens, - torch::Tensor seq_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, - int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged) { - NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && - new_k.scalar_type() == new_v.scalar_type() && - new_k.scalar_type() == k_cache.scalar_type(), - "new_k, new_v, k_cache and 
v_cache must be of the same data type."); - NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_THD, - "qkv_format must be {BSHD, SBHD, THD}."); - if (k_cache.scalar_type() == at::ScalarType::Half) { - using dtype = at::Half; - copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, - seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, - max_seq_len, max_pages_per_seq, is_non_paged); - - } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { - using dtype = at::BFloat16; - copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, - seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, - max_seq_len, max_pages_per_seq, is_non_paged); - } else if (k_cache.scalar_type() == at::ScalarType::Float) { - using dtype = float; - copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, - seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, - max_seq_len, max_pages_per_seq, is_non_paged); - // } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) { - // using dtype = at::kFloat8_e4m3fn; - // copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); - // } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) { - // using dtype = at::kFloat8_e5m2; - // copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, step_lens, seq_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); - } else { - NVTE_ERROR("Unsupported dtype for KV cache.\n"); - } -} - // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, @@ -1346,3 +1028,168 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t 
return output; } + +/*************************************************************************************************** + * KV Cache: Reshape Q from qkv_format = thd to qkv_format = bshd + **************************************************************************************************/ + +template +void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + transformer_engine::fused_attn::reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_q.data_ptr()), + reinterpret_cast(q_buffer.data_ptr()), cu_new_lens.data_ptr(), + h_q, d_q, b, max_ctx_len, max_seq_len); +} + +void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, + int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), + "new_q and q_buffer must be of the same data type."); + if (q_buffer.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + max_seq_len); + } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + max_seq_len); + } else if (q_buffer.scalar_type() == at::ScalarType::Float) { + using dtype = float; + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + max_seq_len); + } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { + using dtype = at::Float8_e4m3fn; + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + max_seq_len); + } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { + using dtype = at::Float8_e5m2; + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + max_seq_len); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } +} + 
+/*************************************************************************************************** + * KV Cache: Reshape O from qkv_format = bshd to qkv_format = thd + **************************************************************************************************/ + +template +void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { + transformer_engine::fused_attn::reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_buffer.data_ptr()), cu_new_lens.data_ptr(), + h_o, d_o, b, max_seq_len, is_output_right_aligned); +} + +void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o, + int d_o, int b, int max_seq_len, bool is_output_right_aligned) { + NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(), + "output and output_buffer must be of the same data type."); + if (output.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); + } else if (output.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); + } else if (output.scalar_type() == at::ScalarType::Float) { + using dtype = float; + reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); + } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { + using dtype = at::Float8_e4m3fn; + reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, + is_output_right_aligned); + } else if (output.scalar_type() == at::ScalarType::Float8_e5m2) { + using dtype = at::Float8_e5m2; + reshape_o_launcher(output, output_buffer, cu_new_lens, 
h_o, d_o, b, max_seq_len, + is_output_right_aligned); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } +} + +/*************************************************************************************************** + * KV Cache: Copy new KV tokens to the KV cache + * 1. new_k and new_v are in qkv_format, and k_cache and v_cache are in 'bshd' format + * 2. cu_new_lens and cu_cached_lens are in shape, [b + 1], and cu_cached_lens are the lens after current step + * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1], + * max_pages_per_seq = 1. Set is_non_paged = True/False accordingly. + * 4. is_non_paged = True re-indexes the cache based on the page_table, i.e. page_table = + * [[0], [3], [1], [2]] will rearrange the cache to be [[0], [1], [1], [2]]. + * 5. k_cache and v_cache should have the same page_table + * 6. For qkv_format = thd, we assume there is no padding between sequences in new_k and new_v, + * e.g. new_k = [a a a b b c], not new_k = [a a a 0..0 b b 0..0 c 0..0]. 
+ **************************************************************************************************/ + +template +void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, + torch::Tensor cu_new_lens, torch::Tensor cu_cached_lens, + NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged) { + if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && + v_cache.data_ptr() != nullptr) { + if (is_non_paged) { + transformer_engine::fused_attn::reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), + cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), h_kv, d_k, d_v, b, max_seq_len); + } + transformer_engine::fused_attn::copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), + cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, b, + max_ctx_len, max_seq_len, max_pages_per_seq); + } +} + +void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, + torch::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, + int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, + bool is_non_paged) { + NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && + new_k.scalar_type() == new_v.scalar_type() && + new_k.scalar_type() == k_cache.scalar_type(), + "new_k, new_v, k_cache and v_cache must be of the same data type."); + NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format 
== NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_THD, + "qkv_format must be {BSHD, SBHD, THD}."); + if (k_cache.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + + } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else if (k_cache.scalar_type() == at::ScalarType::Float) { + using dtype = float; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) { + using dtype = at::Float8_e4m3fn; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) { + using dtype = at::Float8_e5m2; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index ba4f00e77f..93a86deabd 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -171,19 +171,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); - m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache"); - m.def("reshape_q", &reshape_q, "Reshape Q for THD before attention"); - m.def("reshape_o", &reshape_o, "Reshape O for THD after attention"); - m.def("fused_attn_fwd", &fused_attn_fwd, - "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); - m.def("fused_attn_bwd", &fused_attn_bwd, - "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); - m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", - py::call_guard()); - m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", - py::call_guard()); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, @@ -191,6 +180,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); + + // attention kernels + m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", + py::call_guard()); + m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", + py::call_guard()); + m.def("fused_attn_fwd", &fused_attn_fwd, + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); + m.def("fused_attn_bwd", &fused_attn_bwd, + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); + m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache"); + m.def("reshape_q", &reshape_q, "Reshape Q for THD before attention"); + m.def("reshape_o", 
&reshape_o, "Reshape O for THD after attention"); + // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh new file mode 100644 index 0000000000..bf343226f1 --- /dev/null +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -0,0 +1,158 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_ + +namespace transformer_engine { +namespace fused_attn { +template +__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, + int h_q, int d_q, int b, + int max_ctx_len, int max_seq_len) { + // new_q: thd; q_buffer: bshd; + // cu_new_lens: [b + 1] + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = (cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]) * h_q * d_q; + int new_token_offset = cu_new_lens[batch_idx] * h_q * d_q; + int cache_offset = batch_idx * max_seq_len * h_q * d_q; + scalar_t *new_q_token = new_q + new_token_offset; + scalar_t *q_buffer_token = q_buffer + cache_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(q_buffer_token + i) = *(new_q_token + i); + } + } +} + +template +__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens, int h_o, + int d_o, int b, int max_seq_len, bool is_output_right_aligned) { + // output: bshd; output_buffer: thd; + // cu_new_lens: [b + 1] + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + int num_elts = new_len * h_o * d_o; 
+ int output_offset = batch_idx * max_seq_len * h_o * d_o; + if (is_output_right_aligned) { + output_offset = ((batch_idx + 1) * max_seq_len - new_len) * h_o * d_o; + } + int output_buffer_offset = cu_new_lens[batch_idx] * h_o * d_o; + scalar_t *output_token = output + output_offset; + scalar_t *output_buffer_token = output_buffer + output_buffer_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(output_buffer_token + i) = *(output_token + i); + } + } +} + +template +__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices, + int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int d_v, + int b, int max_seq_len) { + // k_cache, v_cache: bshd + // batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1] + int actual_b = b; + for (int i = 0; i < b - 1; i++) { + if (batch_indices[i + 1] < batch_indices[i]) { + actual_b = i + 1; + } + } + for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { + int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; + token_idx += gridDim.x) { + int num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; + int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; + int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; + int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; + for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { + *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); + } + for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { + *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); + } + } + } + if (blockIdx.x == 0) { + for (int batch_idx = threadIdx.x; batch_idx < b; 
batch_idx++) { + batch_indices[batch_idx] = batch_idx; + } + } +} + +template +__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache, + scalar_t *v_cache, int *page_table, int *cu_new_lens, + int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, + int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // cu_new_lens, cu_cached_lens: [b + 1] + // page_table: [b, max_pages_per_seq] + int page_size = max_seq_len / max_pages_per_seq; + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int *page_list = page_table + batch_idx * max_pages_per_seq; + int new_token_offset = batch_idx * max_ctx_len; + int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int page_idx = page_list[(cached_len - new_len + i) / page_size]; + int token_idx = + page_idx * page_size + (cached_len - new_len + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = + *(new_k + (new_token_offset + i) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = + *(new_v + (new_token_offset + i) * h_kv * d_v + j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int *page_list = page_table + batch_idx * max_pages_per_seq; + int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int page_idx = page_list[(cached_len - new_len + i) / page_size]; + int token_idx = + page_idx * page_size + (cached_len - new_len + i) % 
page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int *page_list = page_table + batch_idx * max_pages_per_seq; + int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int page_idx = page_list[(cached_len - new_len + i) / page_size]; + int token_idx = + page_idx * page_size + (cached_len - new_len + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = + *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = + *(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j); + } + } + } + } +} +} // namespace fused_attn +} // namespace transformer_engine +#endif diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index b6011b1f88..05fa4b8010 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -69,6 +69,7 @@ def _make_graphed_callables( """ Helper method for `make_graphed_callables` """ + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): raise RuntimeError( "make_graphed_callables does not support the autocast " @@ -246,7 +247,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] - for ii in range(num_warmup_iters): + for _ in range(num_warmup_iters): hooks = [] for module in 
func.modules(): hook = module.register_forward_hook(hook_fn) diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 74a3493e19..d9c63f9158 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """Inference.""" -import collections +from collections import OrderedDict from typing import Dict, List from einops import rearrange @@ -16,34 +16,43 @@ from transformer_engine.pytorch.kv_cache_manager_non_paged import NonPagedKVCacheManager -class InferenceParams: # pylint: disable=too-few-public-methods +class InferenceParams: """ Inference parameters that are passed to the main model in order - to efficiently calculate and store the context and previously generated tokens - during inference. + to efficiently cache previous tokens and reuse them for the current + inference iteration. Parameters ---------- - max_batch_size : int - maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. - num_heads: int - number of attention heads in key/value tensor. + max_batch_size: int + Maximum batch size in inference + max_seqlen_kv: int + Maximum sequence length in inference + num_heads_kv: int + Number of attention heads in keys and values head_dim_k: int - head size for the key tensor. + Head size for keys dtype: torch.dtype - data type for the KV cache. - head_dim_v: Optional[int], default = None - head size for the value tensor. If None, it will be set to head_dim_k. + Data type of the KV cache + head_dim_v: int, default = None + Head size for values. If None, initialized as head_dim_k. is_paged: bool, default = False - whether the KV cache is paged or non-paged (contiguous). - total_num_pages: Optional[int], default = None - total number of pages in the K cache or V cache if is_paged = True. - page_size: Optional[int], default = None - page size in number of tokens if is_paged = True. 
+ Whether the KV cache is paged (True) or non-paged (False) + total_num_pages: int, default = None + Total number of pages in the KV cache. Required for is_paged = True. + page_size: int, default = None + Page size of the KV cache. Required for is_paged = True. + num_heads_q: int, default = None + Number of attention heads in queries + head_dim_q: int, default = None + Head size for queries. Required for qkv_format = thd. + max_ctx_len: int, default = None + Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. + qkv_format: str, default = "bshd" + Format of the incoming query/key/value tensors in current iteration + cache_manager: KVCacheManager, default = None + Custom cache manager, with KVCacheManager as the base class. """ - def __init__( self, max_batch_size: int, @@ -105,7 +114,7 @@ def __init__( ) if qkv_format == "thd": - # query will be converted to 'bshd' to be consistent with cache format + # query is converted to 'bshd' for certain backends assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" 
@@ -113,32 +122,27 @@ def __init__( self.head_dim_q = head_dim_q self.max_ctx_len = max_ctx_len self.max_seqlen_q = max_ctx_len + self.q_orig = {} + self.q_buffer = {} # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - self.sequences_prev = collections.OrderedDict() - self.sequences = collections.OrderedDict() - self.step_dict = collections.OrderedDict() + self.sequences_prev = OrderedDict() + self.sequences = OrderedDict() + self.step_dict = OrderedDict() self.batch_size = 0 self.cu_seqlens_q = None self.cu_seqlens_kv = None - # original q will be used as the output buffer - self.q_orig = {} - # convert q to 'bshd' to be consistent with cache format - self.q_buffer = {} - self.is_output_right_aligned = False def reset(self): - """ - Reset the state of InferenceParams. - """ - self.sequences = collections.OrderedDict() + """Reset InferenceParams state""" + self.sequences = OrderedDict() self.cache_manager.reset() if self.input_qkv_format == "thd": for layer_number in self.q_buffer: @@ -167,31 +171,16 @@ def __repr__(self) -> str: def allocate_memory(self, layer_number: int, qkv_format: str): """ - Allocate memory for the KV cache for the layer #layer_number. - Both K cache and V cache are in 'bshd' format. - - non-paged: - - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] - - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] - - paged: - - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] - - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] - If is_cuda_graph = True, several buffers are also allocated. - - Q buffer: [max_batch_size, max_seqlen_kv, num_heads_q, head_dim_q] - - cu_seqlens_q buffer: [max_batch_size + 1] - - cu_seqlens_kv buffer: [max_batch_size + 1] + Allocate memory for the cache. 
For layer layer_number, + - NonPagedKVCacheManager: + - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] + - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] + - PagedKVCacheManager: + - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] + - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] """ self.cache_manager.allocate_memory(layer_number) - if qkv_format == "thd": - self.q_buffer[layer_number] = torch.zeros( - self.max_batch_size, - self.max_ctx_len, - self.num_heads_q, - self.head_dim_q, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - self.cu_seqlens_q = torch.zeros( self.max_batch_size + 1, dtype=torch.int32, @@ -205,14 +194,13 @@ def allocate_memory(self, layer_number: int, qkv_format: str): def pre_step( self, - step_dict: Dict[List, List], + step_dict: OrderedDict, ): - """ - Prepare for step(). - """ + """Update tracked sequences and prepare for step()""" self.step_dict = step_dict self.batch_size = len(step_dict) self.sequences_prev = self.sequences + self.sequences = self.cache_manager.pre_step(step_dict) actual_batch_size = len(step_dict) @@ -220,8 +208,8 @@ def pre_step( cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu")) + seq_lens = list(self.sequences.values()) - # seq_lens = [self.max_seqlen_kv] * self.batch_size cu_seqlens_kv = [0] + [sum(seq_lens[:i]) for i in range(1, actual_batch_size + 1)] cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( self.max_batch_size - actual_batch_size @@ -230,21 +218,16 @@ def pre_step( def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): """ - Convert the k cache and v cache from paged to non-paged format. 
This function - can be used for debugging purposes or for backends that do not have paged attention - support yet, for example, UnfusedDotProductAttention. - - It can be called after step(). Based on the page table, it re-indexes the cache - tensors and returns the contiguous, non-paged, key and value tensors. The kv cache tensors - are assumed to be in 'bshd' format (see self.allocate_memory), and the returned key and - value tensors will be in :attr:`qkv_format` to be consistent with the original inputs. + Convert k_cache and v_cache from paged to non-paged format. This is used by the + UnfusedDotProductAttention backend. Both k_cache and v_cache are assumed to be + in 'bshd' format. Parameters ---------- layer_number: int - The layer number of the kv cache + Layer number of attention in the model qkv_format: str - The format of the returned key and value tensors, {'bshd', 'sbhd', 'thd'} + Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Returns ------- @@ -257,7 +240,6 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): page_table = self.cache_manager.page_table batch_size = page_table.shape[0] actual_batch_size = len(self.step_dict) - seqlens = list(self.sequences.values()) new_k_cache = rearrange( k_cache[page_table.flatten()], "(b npages) page_size ... -> b (npages page_size) ...", @@ -268,87 +250,83 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): "(b npages) page_size ... 
-> b (npages page_size) ...", b=batch_size, ) + + new_k_cache = new_k_cache[:actual_batch_size].contiguous() + new_v_cache = new_v_cache[:actual_batch_size].contiguous() + if qkv_format == "sbhd": + new_k_cache = new_k_cache.transpose(0,1) + new_v_cache = new_v_cache.transpose(0,1) if qkv_format == "thd": - new_k_cache = new_k_cache.contiguous() - new_v_cache = new_v_cache.contiguous() - else: - new_k_cache = new_k_cache[:actual_batch_size].contiguous() - new_v_cache = new_v_cache[:actual_batch_size].contiguous() + assert False, "UnfusedDotProductAttention does not support qkv_format=thd." + return new_k_cache, new_v_cache def step( self, layer_number: int, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + new_q: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, qkv_format: str, ): """ - Update KV cache with the new key/value tokens for a given inference iteration. - - NonPagedKVCacheManager and PagedKVCacheManager are two examples of the cache manager. - Users can write their own cache manager with their own step() function. - - If the inference iteration has only generation sequences, :attr:`k` and :attr:`v` tensors - should have shape: - - [batch_size, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', - - [1, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and - - [batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. - - If the inference iteration has both generation sequences and context sequences, :attr:`k` - and :attr:`v` should be arranged in a way so that the sequences in generation phase come - before the sequences in context phase, in the tensor. They should have the following shape. - - [batch_size, max_seqlen, num_heads, head_dim] for :attr:`qkv_format` = 'bshd' - - [max_seqlen, batch_size, num_heads, head_dim] for :attr:`qkv_format` = 'sbhd', and - - [total_num_new_tokens, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. 
- Here, max_seqlen is the maximum sequence length for the new tokens in the batch, and it may - be smaller than InferenceParams.max_seqlen_kv. - - Take a batch of 4, with seq_ids = [0, 1, 2, 3], as an example. At iteration t, all 4 sequences - are processed, after which, sequence 2 is determined to be 'finished'. For iteration t+1, there - may or may not be a new sequence added to the batch. - - If no new sequence is added, input tensors :attr:`k` and :attr:`v` should have shape - [3, 1, num_heads, head_dim] for :attr:`qkv_format` = 'bshd', [1, 3, num_heads, head_dim] for - :attr:`qkv_format` = 'sbhd', and [3, num_heads, head_dim] for :attr:`qkv_format` = 'thd'. - - If one new sequence is added, for example, sequence 8 with 10 context tokens, then input tensors - :attr:`k` and :attr:`v` should be in [4, 10, num_heads, head_dim] shape if - :attr:`qkv_format` = 'bshd', [10, 4, num_heads, head_dim] if :attr:`qkv_format` = 'sbhd', - or [13, num_heads, head_dim] if :attr:`qkv_format` = 'thd'. + Copy the new KV tokens to the KV cache and reshape Q if necessary. 
Parameters ---------- layer_number: int - The layer number of the kv cache - k: torch.Tensor - The new key tokens for the current iteration - v: torch.Tensor - The new value tokens for the current iteration + Layer number of attention in the model + new_q: torch.Tensor + New query tokens for layer_number in current inference iteration + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration qkv_format: str - The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Returns ------- + q_buffer: torch.Tensor + new_q reshaped in order to allow certain backends to execute k_cache: torch.Tensor - The key cache tensor, containing tokens from both previous and current iterations + Full key tensor containing both previous and current key tokens v_cache: torch.Tensor - The value cache tensor, containing tokens from both previous and current iterations + Full value tensor containing both previous and current value tokens page_table: torch.Tensor - The page table if is_paged = True; else `None` + Page table for paged KV cache, [batch_size, max_pages_per_seq]. None for non-paged KV cache + cu_seqlens_q: torch.Tensor + Updated cumulative sequence lengths for query, [batch_size + 1] + cu_seqlens_kv: torch.Tensor + Updated cumulative sequence lengths for key and value, [batch_size + 1] + max_seqlen_q: int + Update maximum sequence length for query + max_seqlen_kv: int + Update maximum sequence length for key and value + qkv_format: str + Updated qkv_format, e.g. 
the input 'thd' format may become 'thd_2bshd' after step() """ self.input_qkv_format = qkv_format self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + if qkv_format == "thd" and layer_number not in self.q_buffer: + self.q_buffer[layer_number] = torch.zeros( + self.max_batch_size, + self.max_ctx_len, + self.num_heads_q, + self.head_dim_q, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + if qkv_format == "bshd": - q_buffer = q.contiguous() + q_buffer = new_q.contiguous() self.max_seqlen_q = q_buffer.shape[1] if qkv_format == "sbhd": - q_buffer = q.transpose(0, 1).contiguous() + q_buffer = new_q.transpose(0, 1).contiguous() self.max_seqlen_q = q_buffer.shape[1] if qkv_format == "thd": - q_buffer = q + q_buffer = new_q # self.q_orig[layer_number] = q # self.max_seqlen_q = self.max_ctx_len @@ -367,8 +345,8 @@ def step( k_cache, v_cache, page_table = self.cache_manager.step( layer_number, - k, - v, + new_k, + new_v, self.cu_seqlens_q, self.cu_seqlens_kv, qkv_format, @@ -392,7 +370,7 @@ def post_step( output: torch.Tensor, ): """ - Process the attention output in order to return it in the original qkv_format. + Process the attention output in order to return it to the original qkv_format. """ if self.input_qkv_format == "bshd": output = output[: self.batch_size, : self.max_seqlen_q].contiguous() diff --git a/transformer_engine/pytorch/kv_cache_manager.py b/transformer_engine/pytorch/kv_cache_manager.py index 4e9bb8353e..3875642efb 100644 --- a/transformer_engine/pytorch/kv_cache_manager.py +++ b/transformer_engine/pytorch/kv_cache_manager.py @@ -2,36 +2,32 @@ # # See LICENSE for license information. -"""KV Cache Manager.""" +"""KV Cache Manager""" from collections import OrderedDict -from typing import Dict, List - import torch class KVCacheManager: - """ - KV cache manager. The base class for custom cache managers. 
- """ + """Base KV cache manager""" def __init__(self, *args, **kwargs): - """Initialize the cache manager.""" + """Initialize cache manager""" self.cache = {} self.sequences = OrderedDict() def reset(self): - """Empty tracked sequences""" + """Reset cache manager state""" self.sequences = OrderedDict() def allocate_memory(self, layer_number: int): - """Allocate memory for the KV cache.""" + """Allocate memory for the cache""" self.cache[layer_number] = (None, None) def pre_step( self, - step_dict: Dict[List, List], + step_dict: OrderedDict, ): - """Prepare for operations in step(). Update sequences with step_dict.""" + """Update tracked sequences and prepare for step()""" return self.sequences def step( @@ -39,9 +35,9 @@ def step( layer_number: int, new_k: torch.Tensor, new_v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, + cu_new_seqlens: torch.Tensor, + cu_cached_seqlens: torch.Tensor, qkv_format: str, ): - """Update the cache with new_k and new_v tokens""" + """Copy the new tokens to KV cache""" return *self.cache[layer_number], None diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py index 598595b99d..ca6f4c225d 100644 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_non_paged.py @@ -2,9 +2,9 @@ # # See LICENSE for license information. -"""Non-Paged KV Cache Manager.""" +"""Non-Paged KV Cache Manager""" from collections import OrderedDict -from typing import Optional, Dict, List +from typing import Optional import torch import transformer_engine_torch as tex from transformer_engine.pytorch.kv_cache_manager import KVCacheManager @@ -12,9 +12,7 @@ class NonPagedKVCacheManager(KVCacheManager): - """ - The non-paged KV cache manager. 
- """ + """Non-paged KV cache manager""" def __init__( self, @@ -25,7 +23,7 @@ def __init__( dtype: torch.dtype, head_dim_v: Optional[int] = None, ): - """Initialize the KV cache""" + """Initialize cache manager""" self.max_batch_size = max_batch_size self.max_seqlen = max_seqlen self.num_heads = num_heads @@ -33,12 +31,15 @@ def __init__( self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - self.cache = {} + # track sequences in the cache, {seq_id: seq_len} self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # track sequence indices in the batch in order to re-index k_cache and v_cache self.batch_indices = None def allocate_memory(self, layer_number): - """Allocate memory for the KV cache""" + """Allocate memory for the cache""" k_cache = torch.zeros( self.max_batch_size, self.max_seqlen, @@ -65,9 +66,13 @@ def allocate_memory(self, layer_number): def pre_step( self, - step_dict: Dict[List, List], + step_dict: OrderedDict, ): - # Reorder cache + """Update tracked sequences and prepare for step()""" + # Track unfinished sequences' indices in the batch, e.g. + # at t-1, seq_ids = [0, 1, 2, 3], and at t, seq_ids = [0, 2, 3], because seq_id 1 finished + # batch_indices = [0, 2, 3, 1] is used in step() to re-index k_cache and v_cache so that + # they are contiguous and match the sequence indexing in q. prev_batch_size = len(self.sequences) unfinished_seqs = self.sequences.keys() & step_dict.keys() finished_seqs = self.sequences.keys() - unfinished_seqs @@ -101,56 +106,58 @@ def pre_step( def step( self, layer_number, - k: torch.Tensor, - v: torch.Tensor, - # step_dict: OrderedDict, - cu_seqlens_q, - cu_seqlens_kv, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, qkv_format: str, ): """ - Update the non-paged KV cache for a given inference iteration. - For more details, please refer to InferenceParams.update_cache(). 
+ Copy the new tokens to the non-paged KV cache. Parameters ---------- layer_number: int - The layer number of kv cache to operate on - k: torch.Tensor - The new key tokens for the current iteration - v: torch.Tensor - The new value tokens for the current iteration - step_dict: OrderedDict - The {seq_id: step_len} information for the new inference step + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] qkv_format: str - The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Returns ------- k_cache: torch.Tensor - The key cache tensor containing previous and the current tokens + Full key tensor containing both previous and current key tokens v_cache: torch.Tensor - The value cache tensor containing previous and the current tokens + Full value tensor containing both previous and current value tokens + page_table: torch.Tensor + None for non-paged KV cache """ k_cache, v_cache = self.cache[layer_number] - step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + batch_size = self.max_batch_size ctx_len = 1 if qkv_format == "bshd": - batch_size = k.shape[0] - ctx_len = k.shape[1] + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] if qkv_format == "sbhd": - batch_size = k.shape[1] - ctx_len = k.shape[0] + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + tex.copy_to_kv_cache( - k, - v, + new_k, + new_v, k_cache, v_cache, self.batch_indices, - step_lens, - seq_lens, + cu_new_seqlens, + cu_cached_seqlens, QKVFormat[qkv_format], 
self.num_heads, self.head_dim_k, @@ -161,6 +168,8 @@ def step( 1, True, ) + k_cache = k_cache[:batch_size] v_cache = v_cache[:batch_size] + return k_cache, v_cache, None diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py index d67740b613..931846cce5 100644 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ b/transformer_engine/pytorch/kv_cache_manager_paged.py @@ -2,9 +2,9 @@ # # See LICENSE for license information. -"""Paged KV Cache Manager.""" +"""Paged KV Cache Manager""" from collections import defaultdict, OrderedDict -from typing import List, Optional, Dict +from typing import List, Optional import logging import torch @@ -31,11 +31,7 @@ def deallocate_page(self): class PagedKVCacheManager(KVCacheManager): - """ - Paged KV cache manager. It supports a set of utilities including adding and removing - sequences, and copying new key/value tokens to the cache. Users can overwrite this class - for more custom implementations. - """ + """Paged KV cache manager""" def __init__( self, @@ -47,9 +43,8 @@ def __init__( max_batch_size: int, max_seqlen: int, head_dim_v: Optional[int] = None, - # is_cuda_graph: bool = False, ): - """Initialize the KV cache""" + """Initialize cache manager""" self.total_num_pages = total_num_pages self.page_size = page_size self.num_heads = num_heads @@ -59,13 +54,12 @@ def __init__( self.max_seqlen = max_seqlen self.max_pages_per_seq = max_seqlen // self.page_size self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - # self.is_cuda_graph = is_cuda_graph - # sequences contained in the kv cache, {seq_id: seq_len} + # track sequences in the cache, {seq_id: seq_len} self.sequences = OrderedDict() - # kv cache, cache[layer_number] = (k_cache, v_cache) + # cache tensors, cache[layer_number] = (k_cache, v_cache) self.cache = {} - # free pages allowed to allocate, [Page(),...] + # available pages, [Page(),...] 
self.free_pages = [] for i in range(self.total_num_pages): self.free_pages.append(Page(i)) @@ -75,6 +69,7 @@ def __init__( self.page_table = None def reset(self): + """Reset cache manager state""" self.sequences = OrderedDict() self.free_pages = [] for i in range(self.total_num_pages): @@ -83,7 +78,7 @@ def reset(self): self.page_table.fill_(0) def allocate_memory(self, layer_number): - """Allocate memory for the KV cache""" + """Allocate memory for the cache""" k_cache = torch.empty( self.total_num_pages, self.page_size, @@ -109,8 +104,8 @@ def allocate_memory(self, layer_number): def print_cache(self): """Print KV cache status""" used_pages = [self.get_page_count(seq) for seq in self.sequences] - logger = logging.getLogger("PagedAttention") - logger.debug("cache status:") + logger = logging.getLogger("PagedKVCacheManager") + logger.debug("Cache status:") logger.debug( " total pages: %s (used %s, free %s)", self.total_num_pages, @@ -148,12 +143,6 @@ def get_page_list(self, seq: int): """Get the list of pages allocated to a sequence""" return [x.page_id for x in self.allocated_pages[seq]] - def get_page_token_offsets(self, seqlen: int): - """Get the relevant page index and token index for a given sequence length""" - page_offset = seqlen // self.page_size - token_offset = seqlen % self.page_size - return (page_offset, token_offset) - def get_page_table(self, sequences: List[int]): """Get the page table, in shape [batch_size, max_pages_per_seq]""" page_table = torch.Tensor( @@ -191,12 +180,9 @@ def deallocate_sequence(self, seq: int): def pre_step( self, - step_dict: Dict[List, List], + step_dict: OrderedDict, ): - batch_size = len(step_dict) - step_lens = list(step_dict.values()) - cu_seqlens = [0] + [sum(step_lens[:i]) for i in range(1, batch_size + 1)] - + """Update tracked sequences and prepare for step()""" # Remove finished sequences and advance unfinished sequences unfinished_seqs = self.sequences.keys() & step_dict.keys() finished_seqs = 
self.sequences.keys() - unfinished_seqs @@ -222,56 +208,58 @@ def pre_step( def step( self, layer_number: int, - k: torch.Tensor, - v: torch.Tensor, - # step_dict: OrderedDict, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, qkv_format: str, ): """ - Update the paged KV cache for a given inference iteration. - For more details, please refer to InferenceParams.update_cache(). + Copy the new tokens to the paged KV cache. Parameters ---------- layer_number: int - The layer number of kv cache to operate on - k: torch.Tensor - A batch of new key tokens for the current iteration - v: torch.Tensor - A batch of new value tokens for the current iteration - step_dict: OrderedDict - The {seq_id: step_len} information for the new inference step + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] qkv_format: str - The format of the new key/value tensors, {'bshd', 'sbhd', 'thd'} + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Returns ------- k_cache: torch.Tensor - The key cache tensor containing previous and the current tokens + Full key tensor containing both previous and current key tokens v_cache: torch.Tensor - The value cache tensor containing previous and the current tokens + Full value tensor containing both previous and current value tokens + page_table: torch.Tensor + Page table for current iteration, in shape [batch_size, max_pages_per_seq] """ k_cache, v_cache = self.cache[layer_number] - step_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - 
seq_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + batch_size = self.max_batch_size ctx_len = 1 if qkv_format == "bshd": - batch_size = k.shape[0] - ctx_len = k.shape[1] + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] if qkv_format == "sbhd": - batch_size = k.shape[1] - ctx_len = k.shape[0] + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + tex.copy_to_kv_cache( - k, - v, + new_k, + new_v, k_cache, v_cache, self.page_table, - step_lens, - seq_lens, + cu_new_seqlens, + cu_cached_seqlens, QKVFormat[qkv_format], self.num_heads, self.head_dim_k, @@ -282,6 +270,7 @@ def step( self.max_pages_per_seq, False, ) + page_table = self.page_table[:batch_size] return k_cache, v_cache, page_table From 93235dd503a26615c7039d6abb80bcaa1ef62063 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 22 Feb 2025 22:51:37 -0800 Subject: [PATCH 109/239] WIP: all qkv_format combinations and merge CM files Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 5 +- transformer_engine/pytorch/attention.py | 50 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cu | 16 +- transformer_engine/pytorch/csrc/kv_cache.cuh | 3 +- transformer_engine/pytorch/inference.py | 611 ++++++++++++++++-- .../pytorch/kv_cache_manager.py | 43 -- .../pytorch/kv_cache_manager_non_paged.py | 175 ----- .../pytorch/kv_cache_manager_paged.py | 276 -------- 9 files changed, 570 insertions(+), 611 deletions(-) delete mode 100644 transformer_engine/pytorch/kv_cache_manager.py delete mode 100644 transformer_engine/pytorch/kv_cache_manager_non_paged.py delete mode 100644 transformer_engine/pytorch/kv_cache_manager_paged.py diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index b347340d3b..6c25a4521d 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -200,9 +200,9 @@ def step(self, dynamic_fill: bool = True): 
@pytest.mark.parametrize("dtype", [torch.float16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) -@pytest.mark.parametrize("qkv_format", ["thd"]) # qkv_formats) +@pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): reset_rng_states() @@ -319,6 +319,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): head_dim_q=config.head_dim_qk, max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, + allow_query_conversion=backend!="FusedAttention", ) inference_params.allocate_memory(layer_number, qkv_format) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fe20f91b12..0a41895a8a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -547,6 +547,11 @@ def get_attention_backend( "Disabling FlashAttention as KV caching requires flash-attn 2.2+, or 3.0" " (Hopper only)" ) + if use_fused_attention and pad_between_seqs: + use_fused_attention = False + logger.debug( + "Disabling FusedAttention for pad_between_seqs = True and KV caching" + ) if inference_params.is_paged: if use_fused_attention and cudnn_version < (9, 5, 0): logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+") @@ -5527,7 +5532,6 @@ def get_qkv_layout( q_format = qkv_format kv_format = qkv_format is_same_q_kv_format = True - print("qkv format", qkv_format, is_same_q_kv_format, q_format, kv_format) def run_iteratively(q, k, v): # check data pointers @@ -5616,7 +5620,6 @@ def run_iteratively(q, k, v): # when consecutive, they may have the same data pointer, 
i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = "_".join(list([qkv_format]) * 3) - print("xxxxx0") elif ( check_strides_kv and check_shapes_kv @@ -5628,7 +5631,6 @@ def run_iteratively(q, k, v): # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = q_format + "_" + kv_format + "_" + kv_format - print("xxxxx1") else: qkv_layout = "not_supported" @@ -5932,9 +5934,8 @@ def forward( if inference_params is not None: func = flash_attn_with_kvcache fa_optional_forward_kwargs_kvcache = {} - fa_optional_forward_kwargs_kvcache["cache_seqlens"] = ( - cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - ) + cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_optional_forward_kwargs_kvcache["cache_seqlens"] = cache_seqlens fa_optional_forward_kwargs_kvcache["softmax_scale"] = self.softmax_scale fa_optional_forward_kwargs_kvcache["causal"] = "causal" in attn_mask_type if inference_params.is_paged: @@ -7506,8 +7507,6 @@ def forward( # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference - # if "padding" not in attn_mask_type: - # attn_mask_type = "padding_" + attn_mask_type if attn_mask_type in ["causal", "padding_causal"]: attn_mask_type = attn_mask_type + "_bottom_right" @@ -7523,19 +7522,7 @@ def forward( for x in [query_layer, key_layer, value_layer] ] - # reshape the query tensor - # cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but - # flash-attention and unfused attention will need q/k/v in the - # same qkv_format - # target_qkv_format = inference_params.qkv_format - # query_layer = inference_params.reshape_and_copy_q( - # query_layer, qkv_format, target_qkv_format, self.layer_number - # ) - # update KV cache and return the full key/value tensors - # full key/value tensors are in inference_params.qkv_format format - # 
print('query_layer',query_layer.shape, query_layer.dtype) - # print('query_layer', query_layer[8,0,:4]) ( query_layer, key_layer, @@ -7553,17 +7540,8 @@ def forward( value_layer, qkv_format, ) - # print('ssss0 ',query_layer.shape, key_layer.shape, value_layer.shape) - # print('cu_seqlens_q',cu_seqlens_q) - # print('cu_seqlens_kv',cu_seqlens_kv) - # print('maxxxxx ',max_seqlen_q, max_seqlen_kv) - - # update cu_seqlens tensors - # if inference_params.is_cuda_graph: - # cu_seqlens_q = inference_params.cu_seqlens_q_buffer - # cu_seqlens_kv = inference_params.cu_seqlens_kv_buffer - # max_seqlen_q = inference_params.max_seqlen_q - # max_seqlen_kv = inference_params.max_seqlen_kv + cu_seqlens_q_padded = None + cu_seqlens_kv_padded = None if ( isinstance(query_layer, Float8Tensor) @@ -7586,8 +7564,6 @@ def forward( ) # convert qkv layout to its corresponding paged attention layout if inference_params is not None and inference_params.is_paged: - # qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format - # qkv_layout = "paged_kv_thd_2bshd"# + qkv_format + "_2" + qkv_format qkv_layout = "paged_kv_" + qkv_layout cp_size = 1 @@ -7632,10 +7608,6 @@ def forward( max_seqlen_kv, key_layer.device, ) - # print('max_seqlen_q ', max_seqlen_q) - # print('max_seqlen_kv ', max_seqlen_kv) - # print('cu_seqlens_q ', cu_seqlens_q) - # print('cu_seqlens_kv ', cu_seqlens_kv) global _alibi_cache if alibi_slopes is not None: @@ -7860,9 +7832,6 @@ def forward( fp8_meta=self.fp8_meta, quantizers=self.quantizers, ) - # print('ooooooooooo ',output.shape) - # print(output[1,9,:4]) - # print(output[1,10,:4]) from .cpu_offload import CPUOffloadEnabled @@ -8623,6 +8592,7 @@ def forward( # pylint: disable=fixme # TODO: consider cases where sequences have different seqlens + # sequence_start = inference_params.get_seqlens_pre_step() sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length diff --git a/transformer_engine/pytorch/csrc/extensions.h 
b/transformer_engine/pytorch/csrc/extensions.h index 6689c89d73..69dadcaf59 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -71,7 +71,7 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, - int h_q, int d_q, int b, int max_ctx_len, int max_seq_len); + int h_q, int d_q, int b, int max_seq_len); void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned); void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 1caf6b15e4..9d44c09714 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1035,36 +1035,36 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t template void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, - int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + int h_q, int d_q, int b, int max_seq_len) { transformer_engine::fused_attn::reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(new_q.data_ptr()), reinterpret_cast(q_buffer.data_ptr()), cu_new_lens.data_ptr(), - h_q, d_q, b, max_ctx_len, max_seq_len); + h_q, d_q, b, max_seq_len); } void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, - int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) { + int h_q, int d_q, int b, int max_seq_len) { NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), "new_q and q_buffer must be of the same data type."); if (q_buffer.scalar_type() == 
at::ScalarType::Half) { using dtype = at::Half; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { using dtype = at::Float8_e4m3fn; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { using dtype = at::Float8_e5m2; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len, + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh index bf343226f1..2772bae59b 100644 --- a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -10,8 +10,7 @@ namespace transformer_engine { namespace fused_attn { template __global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, - int h_q, int d_q, int b, - int max_ctx_len, int max_seq_len) { + int h_q, int d_q, int b, int max_seq_len) { // new_q: thd; q_buffer: bshd; // cu_new_lens: [b + 1] for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { diff --git a/transformer_engine/pytorch/inference.py 
b/transformer_engine/pytorch/inference.py index d9c63f9158..ac760d685e 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -3,17 +3,54 @@ # See LICENSE for license information. """Inference.""" -from collections import OrderedDict -from typing import Dict, List +import os +from collections import OrderedDict, defaultdict +from typing import Optional, Dict, List from einops import rearrange +import logging import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -from transformer_engine.pytorch.kv_cache_manager import KVCacheManager -from transformer_engine.pytorch.kv_cache_manager_paged import PagedKVCacheManager -from transformer_engine.pytorch.kv_cache_manager_non_paged import NonPagedKVCacheManager + +__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"] + + +class KVCacheManager: + """Base KV cache manager""" + + def __init__(self, *args, **kwargs): + """Initialize cache manager""" + self.cache = {} + self.sequences = OrderedDict() + + def reset(self): + """Reset cache manager state""" + self.sequences = OrderedDict() + + def allocate_memory(self, layer_number: int): + """Allocate memory for the cache""" + self.cache[layer_number] = (None, None) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens: torch.Tensor, + cu_cached_seqlens: torch.Tensor, + qkv_format: str, + ): + """Copy the new tokens to KV cache""" + return *self.cache[layer_number], None class InferenceParams: @@ -45,13 +82,25 @@ class InferenceParams: num_heads_q: int, default = None Number of attention heads in queries head_dim_q: int, default = None - Head size for queries. Required for qkv_format = thd. + Head size for queries. 
Required for qkv_format = 'thd'. max_ctx_len: int, default = None Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. qkv_format: str, default = "bshd" Format of the incoming query/key/value tensors in current iteration cache_manager: KVCacheManager, default = None Custom cache manager, with KVCacheManager as the base class. + allow_query_conversion: bool, default = True + InferenceParams only supports cache_qkv_format = 'bshd'. When qkv_format = {'sbhd', 'thd'}, + output_qkv_format = {'sbhd_2bshd', 'thd_2bshd'}, which are supported by FusedAttention but + not by FlashAttention or UnfusedDotProductAttention. + + For performance, try allow_query_conversion = False first. If it errors out with "No dot + product attention support for the provided inputs!", set allow_query_conversion = True. + + For functionality, set allow_query_conversion = True. InferenceParams converts query from + {'sbhd', 'thd'} to 'bshd', and converts the output back to {'sbhd', 'thd'}. The cost is + two transposes for qkv_format = 'sbhd', and one memory buffer (q_buffer) and two conversion + kernels (reshape_q and reshape_o) for qkv_format = 'thd'. 
""" def __init__( self, @@ -69,6 +118,7 @@ def __init__( max_ctx_len: int = None, qkv_format: str = "bshd", cache_manager: KVCacheManager = None, + allow_query_conversion: bool = True, ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv @@ -77,6 +127,10 @@ def __init__( self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k self.is_paged = is_paged + _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) + _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) + _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + self.allow_query_conversion = allow_query_conversion and (_NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN) if not self.is_paged: cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager @@ -114,23 +168,27 @@ def __init__( ) if qkv_format == "thd": - # query is converted to 'bshd' for certain backends - assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" - assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" - self.num_heads_q = num_heads_q - self.head_dim_q = head_dim_q self.max_ctx_len = max_ctx_len - self.max_seqlen_q = max_ctx_len - self.q_orig = {} - self.q_buffer = {} + if self.allow_query_conversion: + # query is converted to 'bshd' for certain backends + assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" + assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" 
+ self.num_heads_q = num_heads_q + self.head_dim_q = head_dim_q + self.max_seqlen_q = max_ctx_len + self.q_orig = {} + self.q_buffer = {} # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format - self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + if self.input_qkv_format == self.cache_qkv_format or self.allow_query_conversion: + self.output_qkv_format = self.cache_qkv_format + else: + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - self.sequences_prev = OrderedDict() + self.sequences_pre = OrderedDict() self.sequences = OrderedDict() self.step_dict = OrderedDict() self.batch_size = 0 @@ -144,7 +202,7 @@ def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() self.cache_manager.reset() - if self.input_qkv_format == "thd": + if self.input_qkv_format == "thd" and self.allow_query_conversion: for layer_number in self.q_buffer: self.q_buffer[layer_number].fill_(0) @@ -192,6 +250,16 @@ def allocate_memory(self, layer_number: int, qkv_format: str): device=torch.cuda.current_device(), ) + if qkv_format == "thd" and self.allow_query_conversion: + self.q_buffer[layer_number] = torch.zeros( + self.max_batch_size, + self.max_ctx_len, + self.num_heads_q, + self.head_dim_q, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + def pre_step( self, step_dict: OrderedDict, @@ -199,9 +267,10 @@ def pre_step( """Update tracked sequences and prepare for step()""" self.step_dict = step_dict self.batch_size = len(step_dict) - self.sequences_prev = self.sequences self.sequences = self.cache_manager.pre_step(step_dict) + for k,v in enumerate(self.sequences): + self.sequences_pre[k] = self.sequences[k] - self.step_dict[k] actual_batch_size = len(step_dict) seqlens_q = list(step_dict.values()) @@ -216,6 +285,10 @@ def pre_step( ) self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, 
device="cpu")) + def get_seqlens_pre_step(self): + """Get cached sequence lengths for current iteration before adding step_dict.values""" + return self.sequences_pre + def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): """ Convert k_cache and v_cache from paged to non-paged format. This is used by the @@ -251,13 +324,11 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): b=batch_size, ) - new_k_cache = new_k_cache[:actual_batch_size].contiguous() - new_v_cache = new_v_cache[:actual_batch_size].contiguous() - if qkv_format == "sbhd": - new_k_cache = new_k_cache.transpose(0,1) - new_v_cache = new_v_cache.transpose(0,1) - if qkv_format == "thd": - assert False, "UnfusedDotProductAttention does not support qkv_format=thd." + new_k_cache = new_k_cache.contiguous() + new_v_cache = new_v_cache.contiguous() + if qkv_format != "thd": + new_k_cache = new_k_cache[:actual_batch_size] + new_v_cache = new_v_cache[:actual_batch_size] return new_k_cache, new_v_cache @@ -307,41 +378,28 @@ def step( Updated qkv_format, e.g. 
the input 'thd' format may become 'thd_2bshd' after step() """ self.input_qkv_format = qkv_format - self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - - if qkv_format == "thd" and layer_number not in self.q_buffer: - self.q_buffer[layer_number] = torch.zeros( - self.max_batch_size, - self.max_ctx_len, - self.num_heads_q, - self.head_dim_q, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) + if self.input_qkv_format == self.cache_qkv_format or self.allow_query_conversion: + self.output_qkv_format = self.cache_qkv_format + else: + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + q_buffer = new_q if qkv_format == "bshd": + self.max_seqlen_q = new_q.shape[1] q_buffer = new_q.contiguous() - self.max_seqlen_q = q_buffer.shape[1] if qkv_format == "sbhd": - q_buffer = new_q.transpose(0, 1).contiguous() - self.max_seqlen_q = q_buffer.shape[1] + self.max_seqlen_q = new_q.shape[0] + if self.allow_query_conversion: + q_buffer = new_q.transpose(0, 1).contiguous() if qkv_format == "thd": - q_buffer = new_q - # self.q_orig[layer_number] = q - # self.max_seqlen_q = self.max_ctx_len - - # q_buffer = self.q_buffer[layer_number] - # step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - # ctx_len = 1 - # if qkv_format == "bshd": - # ctx_len = q.shape[1] - # if qkv_format == "sbhd": - # ctx_len = q.shape[0] - # tex.reshape_q( - # q, q_buffer, step_lens, - # QKVFormat[qkv_format], - # self.num_heads_q, self.head_dim_q, - # self.max_batch_size, ctx_len, self.max_ctx_len) + self.max_seqlen_q = self.max_ctx_len + if self.allow_query_conversion: + q_buffer = self.q_buffer[layer_number] + tex.reshape_q( + new_q, self.q_buffer[layer_number], self.cu_seqlens_q, + self.num_heads_q, self.head_dim_q, + self.max_batch_size, self.max_ctx_len) + self.q_orig[layer_number] = new_q k_cache, v_cache, page_table = self.cache_manager.step( layer_number, @@ -374,15 +432,440 @@ def post_step( """ if self.input_qkv_format == 
"bshd": output = output[: self.batch_size, : self.max_seqlen_q].contiguous() - if self.input_qkv_format == "sbhd": + if self.input_qkv_format == "sbhd" and self.allow_query_conversion: output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() - if self.input_qkv_format == "thd": - print("oooo ", output.shape) - # print(output[:2, :4]) - # output_buffer = self.q_orig[layer_number] - # step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - # tex.reshape_o(output, output_buffer, step_lens, - # self.num_heads_q, self.head_dim_q, self.batch_size, self.max_ctx_len, self.is_output_right_aligned) - # output = output_buffer.view(output_buffer.shape[0], -1) + if self.input_qkv_format == "thd" and self.allow_query_conversion: + output_buffer = self.q_orig[layer_number] + tex.reshape_o(output, output_buffer, self.cu_seqlens_q, + self.num_heads_q, self.head_dim_q, self.batch_size, + self.max_ctx_len, self.is_output_right_aligned) + output = output_buffer.view(output_buffer.shape[0], -1) return output + + +class NonPagedKVCacheManager(KVCacheManager): + """Non-paged KV cache manager""" + + def __init__( + self, + max_batch_size: int, + max_seqlen: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + ): + """Initialize cache manager""" + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + + # track sequences in the cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # track sequence indices in the batch in order to re-index k_cache and v_cache + self.batch_indices = None + + def allocate_memory(self, layer_number): + """Allocate memory for the cache""" + k_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + 
self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + self.batch_indices = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + # Track unfinished sequences' indices in the batch, e.g. + # at t-1, seq_ids = [0, 1, 2, 3], and at t, seq_ids = [0, 2, 3], because seq_id 1 finished + # batch_indices = [0, 2, 3, 1] is used in step() to re-index k_cache and v_cache so that + # they are contiguous and match the sequence indexing in q. + prev_batch_size = len(self.sequences) + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] + finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] + self.batch_indices.copy_( + torch.Tensor( + ( + unfinished_indices + + finished_indices + + list(range(prev_batch_size, self.max_batch_size)) + ) + ).to(dtype=torch.int32, device="cpu") + ) + + # Advance unfinished sequences + for i in unfinished_seqs: + self.sequences[i] += 1 + + # Remove finished sequences + for i in finished_seqs: + self.sequences.pop(i) + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for i in new_seqs: + self.sequences[i] = step_dict[i] + + return self.sequences + + def step( + self, + layer_number, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, + qkv_format: str, + ): + """ + Copy the new tokens to the non-paged KV cache. 
+ + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + qkv_format: str + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + page_table: torch.Tensor + None for non-paged KV cache + """ + k_cache, v_cache = self.cache[layer_number] + + batch_size = self.max_batch_size + ctx_len = 1 + if qkv_format == "bshd": + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] + if qkv_format == "sbhd": + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + + tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + self.batch_indices, + cu_new_seqlens, + cu_cached_seqlens, + QKVFormat[qkv_format], + self.num_heads, + self.head_dim_k, + self.head_dim_v, + batch_size, + ctx_len, + self.max_seqlen, + 1, + True, + ) + + k_cache = k_cache[:batch_size] + v_cache = v_cache[:batch_size] + + return k_cache, v_cache, None + + +class Page: + """A single page""" + + def __init__(self, page_id: int): + """Initialize a page""" + self.page_id = page_id + self.allocated = 0 + + def allocate_page(self): + """Allocate a page""" + self.allocated = True + + def deallocate_page(self): + """Deallocate a page""" + self.allocated = False + + +class PagedKVCacheManager(KVCacheManager): + """Paged KV cache manager""" + + def __init__( + self, + total_num_pages: int, + page_size: int, + num_heads: int, + head_dim_k: int, + 
dtype: torch.dtype, + max_batch_size: int, + max_seqlen: int, + head_dim_v: Optional[int] = None, + ): + """Initialize cache manager""" + self.total_num_pages = total_num_pages + self.page_size = page_size + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.max_pages_per_seq = max_seqlen // self.page_size + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + + # track sequences in the cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # available pages, [Page(),...] + self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + # allocated pages, {seq_id: [page_id,...]} + self.allocated_pages = defaultdict(list) + # page table, [batch_size, max_pages_per_seq] + self.page_table = None + + def reset(self): + """Reset cache manager state""" + self.sequences = OrderedDict() + self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + self.allocated_pages = defaultdict(list) + self.page_table.fill_(0) + + def allocate_memory(self, layer_number): + """Allocate memory for the cache""" + k_cache = torch.empty( + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.empty( + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + self.page_table = torch.zeros( + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + ) + + def print_cache(self): + """Print KV cache status""" + used_pages = [self.get_page_count(seq) for seq in self.sequences] + logger = logging.getLogger("PagedKVCacheManager") + logger.debug("Cache status:") + 
logger.debug( + " total pages: %s (used %s, free %s)", + self.total_num_pages, + sum(used_pages), + len(self.free_pages), + ) + logger.debug(" total sequences: %s", self.get_sequence_count()) + for i, seq in enumerate(self.sequences): + logger.debug( + " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", + i, + seq, + self.get_sequence_lengths()[i], + self.get_page_count(seq), + self.get_page_list(seq), + ) + + def get_sequence_count(self): + """Get the total number of sequences in the KV cache""" + return len(self.sequences) + + def get_sequence_lengths(self): + """Get the list of sequence lengths in the KV cache""" + return list(self.sequences.values()) + + def has_free_page(self) -> bool: + """Whether the page pool has any free pages left""" + return len(self.free_pages) > 0 + + def get_page_count(self, seq: int): + """Get the number of pages allocated to a sequence""" + return len(self.allocated_pages[seq]) + + def get_page_list(self, seq: int): + """Get the list of pages allocated to a sequence""" + return [x.page_id for x in self.allocated_pages[seq]] + + def get_page_table(self, sequences: List[int]): + """Get the page table, in shape [batch_size, max_pages_per_seq]""" + page_table = torch.Tensor( + [ + self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) + for seq in sequences + ] + ).to(dtype=torch.int32, device="cpu") + self.page_table[: self.get_sequence_count()].copy_(page_table) + return self.page_table + + def allocate_page(self, seq: int): + """Allocate a new page to a sequence""" + if not self.has_free_page(): + raise RuntimeError("KV cache is full!") + page = self.free_pages.pop(0) + page.allocate_page() + self.allocated_pages[seq].append(page) + + def allocate_sequence(self, seq: int, context_len: int): + """Add a new sequence to the cache""" + num_pages = context_len // self.page_size + if context_len % self.page_size > 0: + num_pages = num_pages + 1 + for _ in range(num_pages): + 
self.allocate_page(seq) + + def deallocate_sequence(self, seq: int): + """Deallocate all the pages for a sequence""" + for page in self.allocated_pages[seq]: + page.deallocate_page() + if not page.allocated: + self.free_pages.append(page) + self.allocated_pages.pop(seq) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + # Remove finished sequences and advance unfinished sequences + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + for seq in finished_seqs: + self.sequences.pop(seq) + self.deallocate_sequence(seq) + for seq in unfinished_seqs: + if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: + self.allocate_page(seq) + self.sequences[seq] += 1 + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for seq in new_seqs: + self.sequences[seq] = step_dict[seq] + self.allocate_sequence(seq, step_dict[seq]) + + # Get page table + self.page_table = self.get_page_table(list(self.sequences.keys())) + + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, + qkv_format: str, + ): + """ + Copy the new tokens to the paged KV cache. 
+ + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + qkv_format: str + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + page_table: torch.Tensor + Page table for current iteration, in shape [batch_size, max_pages_per_seq] + """ + k_cache, v_cache = self.cache[layer_number] + + batch_size = self.max_batch_size + ctx_len = 1 + if qkv_format == "bshd": + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] + if qkv_format == "sbhd": + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + + tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + self.page_table, + cu_new_seqlens, + cu_cached_seqlens, + QKVFormat[qkv_format], + self.num_heads, + self.head_dim_k, + self.head_dim_v, + batch_size, + ctx_len, + self.max_seqlen, + self.max_pages_per_seq, + False, + ) + + page_table = self.page_table[:batch_size] + + return k_cache, v_cache, page_table diff --git a/transformer_engine/pytorch/kv_cache_manager.py b/transformer_engine/pytorch/kv_cache_manager.py deleted file mode 100644 index 3875642efb..0000000000 --- a/transformer_engine/pytorch/kv_cache_manager.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""KV Cache Manager""" -from collections import OrderedDict -import torch - - -class KVCacheManager: - """Base KV cache manager""" - - def __init__(self, *args, **kwargs): - """Initialize cache manager""" - self.cache = {} - self.sequences = OrderedDict() - - def reset(self): - """Reset cache manager state""" - self.sequences = OrderedDict() - - def allocate_memory(self, layer_number: int): - """Allocate memory for the cache""" - self.cache[layer_number] = (None, None) - - def pre_step( - self, - step_dict: OrderedDict, - ): - """Update tracked sequences and prepare for step()""" - return self.sequences - - def step( - self, - layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_new_seqlens: torch.Tensor, - cu_cached_seqlens: torch.Tensor, - qkv_format: str, - ): - """Copy the new tokens to KV cache""" - return *self.cache[layer_number], None diff --git a/transformer_engine/pytorch/kv_cache_manager_non_paged.py b/transformer_engine/pytorch/kv_cache_manager_non_paged.py deleted file mode 100644 index ca6f4c225d..0000000000 --- a/transformer_engine/pytorch/kv_cache_manager_non_paged.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Non-Paged KV Cache Manager""" -from collections import OrderedDict -from typing import Optional -import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch.kv_cache_manager import KVCacheManager -from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat - - -class NonPagedKVCacheManager(KVCacheManager): - """Non-paged KV cache manager""" - - def __init__( - self, - max_batch_size: int, - max_seqlen: int, - num_heads: int, - head_dim_k: int, - dtype: torch.dtype, - head_dim_v: Optional[int] = None, - ): - """Initialize cache manager""" - self.max_batch_size = max_batch_size - self.max_seqlen = max_seqlen - self.num_heads = num_heads - self.head_dim_k = head_dim_k - self.dtype = dtype - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - - # track sequences in the cache, {seq_id: seq_len} - self.sequences = OrderedDict() - # cache tensors, cache[layer_number] = (k_cache, v_cache) - self.cache = {} - # track sequence indices in the batch in order to re-index k_cache and v_cache - self.batch_indices = None - - def allocate_memory(self, layer_number): - """Allocate memory for the cache""" - k_cache = torch.zeros( - self.max_batch_size, - self.max_seqlen, - self.num_heads, - self.head_dim_k, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - v_cache = torch.zeros( - self.max_batch_size, - self.max_seqlen, - self.num_heads, - self.head_dim_v, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - self.cache[layer_number] = (k_cache, v_cache) - - self.batch_indices = torch.zeros( - self.max_batch_size, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - - def pre_step( - self, - step_dict: OrderedDict, - ): - """Update tracked sequences and prepare for step()""" - # Track unfinished sequences' indices in the batch, e.g. 
- # at t-1, seq_ids = [0, 1, 2, 3], and at t, seq_ids = [0, 2, 3], because seq_id 1 finished - # batch_indices = [0, 2, 3, 1] is used in step() to re-index k_cache and v_cache so that - # they are contiguous and match the sequence indexing in q. - prev_batch_size = len(self.sequences) - unfinished_seqs = self.sequences.keys() & step_dict.keys() - finished_seqs = self.sequences.keys() - unfinished_seqs - unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] - finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( - torch.Tensor( - ( - unfinished_indices - + finished_indices - + list(range(prev_batch_size, self.max_batch_size)) - ) - ).to(dtype=torch.int32, device="cpu") - ) - - # Advance unfinished sequences - for i in unfinished_seqs: - self.sequences[i] += 1 - - # Remove finished sequences - for i in finished_seqs: - self.sequences.pop(i) - - # Add new sequences - new_seqs = step_dict.keys() - self.sequences.keys() - for i in new_seqs: - self.sequences[i] = step_dict[i] - - return self.sequences - - def step( - self, - layer_number, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_new_seqlens, - cu_cached_seqlens, - qkv_format: str, - ): - """ - Copy the new tokens to the non-paged KV cache. 
- - Parameters - ---------- - layer_number: int - Layer number of attention in the model - new_k: torch.Tensor - New key tokens for layer_number in current inference iteration - new_v: torch.Tensor - New value tokens for layer_number in current inference iteration - cu_new_seqlens: torch.Tensor - Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] - cu_cached_seqlens: torch.Tensor - Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] - qkv_format: str - Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - Full key tensor containing both previous and current key tokens - v_cache: torch.Tensor - Full value tensor containing both previous and current value tokens - page_table: torch.Tensor - None for non-paged KV cache - """ - k_cache, v_cache = self.cache[layer_number] - - batch_size = self.max_batch_size - ctx_len = 1 - if qkv_format == "bshd": - batch_size = new_k.shape[0] - ctx_len = new_k.shape[1] - if qkv_format == "sbhd": - batch_size = new_k.shape[1] - ctx_len = new_k.shape[0] - - tex.copy_to_kv_cache( - new_k, - new_v, - k_cache, - v_cache, - self.batch_indices, - cu_new_seqlens, - cu_cached_seqlens, - QKVFormat[qkv_format], - self.num_heads, - self.head_dim_k, - self.head_dim_v, - batch_size, - ctx_len, - self.max_seqlen, - 1, - True, - ) - - k_cache = k_cache[:batch_size] - v_cache = v_cache[:batch_size] - - return k_cache, v_cache, None diff --git a/transformer_engine/pytorch/kv_cache_manager_paged.py b/transformer_engine/pytorch/kv_cache_manager_paged.py deleted file mode 100644 index 931846cce5..0000000000 --- a/transformer_engine/pytorch/kv_cache_manager_paged.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Paged KV Cache Manager""" -from collections import defaultdict, OrderedDict -from typing import List, Optional -import logging - -import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch.kv_cache_manager import KVCacheManager -from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat - - -class Page: - """A single page""" - - def __init__(self, page_id: int): - """Initialize a page""" - self.page_id = page_id - self.allocated = 0 - - def allocate_page(self): - """Allocate a page""" - self.allocated = True - - def deallocate_page(self): - """Deallocate a page""" - self.allocated = False - - -class PagedKVCacheManager(KVCacheManager): - """Paged KV cache manager""" - - def __init__( - self, - total_num_pages: int, - page_size: int, - num_heads: int, - head_dim_k: int, - dtype: torch.dtype, - max_batch_size: int, - max_seqlen: int, - head_dim_v: Optional[int] = None, - ): - """Initialize cache manager""" - self.total_num_pages = total_num_pages - self.page_size = page_size - self.num_heads = num_heads - self.head_dim_k = head_dim_k - self.dtype = dtype - self.max_batch_size = max_batch_size - self.max_seqlen = max_seqlen - self.max_pages_per_seq = max_seqlen // self.page_size - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - - # track sequences in the cache, {seq_id: seq_len} - self.sequences = OrderedDict() - # cache tensors, cache[layer_number] = (k_cache, v_cache) - self.cache = {} - # available pages, [Page(),...] 
- self.free_pages = [] - for i in range(self.total_num_pages): - self.free_pages.append(Page(i)) - # allocated pages, {seq_id: [page_id,...]} - self.allocated_pages = defaultdict(list) - # page table, [batch_size, max_pages_per_seq] - self.page_table = None - - def reset(self): - """Reset cache manager state""" - self.sequences = OrderedDict() - self.free_pages = [] - for i in range(self.total_num_pages): - self.free_pages.append(Page(i)) - self.allocated_pages = defaultdict(list) - self.page_table.fill_(0) - - def allocate_memory(self, layer_number): - """Allocate memory for the cache""" - k_cache = torch.empty( - self.total_num_pages, - self.page_size, - self.num_heads, - self.head_dim_k, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - v_cache = torch.empty( - self.total_num_pages, - self.page_size, - self.num_heads, - self.head_dim_v, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - self.cache[layer_number] = (k_cache, v_cache) - - self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" - ) - - def print_cache(self): - """Print KV cache status""" - used_pages = [self.get_page_count(seq) for seq in self.sequences] - logger = logging.getLogger("PagedKVCacheManager") - logger.debug("Cache status:") - logger.debug( - " total pages: %s (used %s, free %s)", - self.total_num_pages, - sum(used_pages), - len(self.free_pages), - ) - logger.debug(" total sequences: %s", self.get_sequence_count()) - for i, seq in enumerate(self.sequences): - logger.debug( - " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", - i, - seq, - self.get_sequence_lengths()[i], - self.get_page_count(seq), - self.get_page_list(seq), - ) - - def get_sequence_count(self): - """Get the total number of sequences in the KV cache""" - return len(self.sequences) - - def get_sequence_lengths(self): - """Get the list of sequence lengths in the KV cache""" - return list(self.sequences.values()) - - 
def has_free_page(self) -> bool: - """Whether the page pool has any free pages left""" - return len(self.free_pages) > 0 - - def get_page_count(self, seq: int): - """Get the number of pages allocated to a sequence""" - return len(self.allocated_pages[seq]) - - def get_page_list(self, seq: int): - """Get the list of pages allocated to a sequence""" - return [x.page_id for x in self.allocated_pages[seq]] - - def get_page_table(self, sequences: List[int]): - """Get the page table, in shape [batch_size, max_pages_per_seq]""" - page_table = torch.Tensor( - [ - self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) - for seq in sequences - ] - ).to(dtype=torch.int32, device="cpu") - self.page_table[: self.get_sequence_count()].copy_(page_table) - return self.page_table - - def allocate_page(self, seq: int): - """Allocate a new page to a sequence""" - if not self.has_free_page(): - raise RuntimeError("KV cache is full!") - page = self.free_pages.pop(0) - page.allocate_page() - self.allocated_pages[seq].append(page) - - def allocate_sequence(self, seq: int, context_len: int): - """Add a new sequence to the cache""" - num_pages = context_len // self.page_size - if context_len % self.page_size > 0: - num_pages = num_pages + 1 - for _ in range(num_pages): - self.allocate_page(seq) - - def deallocate_sequence(self, seq: int): - """Deallocate all the pages for a sequence""" - for page in self.allocated_pages[seq]: - page.deallocate_page() - if not page.allocated: - self.free_pages.append(page) - self.allocated_pages.pop(seq) - - def pre_step( - self, - step_dict: OrderedDict, - ): - """Update tracked sequences and prepare for step()""" - # Remove finished sequences and advance unfinished sequences - unfinished_seqs = self.sequences.keys() & step_dict.keys() - finished_seqs = self.sequences.keys() - unfinished_seqs - for seq in finished_seqs: - self.sequences.pop(seq) - self.deallocate_sequence(seq) - for seq in unfinished_seqs: - if 
self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: - self.allocate_page(seq) - self.sequences[seq] += 1 - - # Add new sequences - new_seqs = step_dict.keys() - self.sequences.keys() - for seq in new_seqs: - self.sequences[seq] = step_dict[seq] - self.allocate_sequence(seq, step_dict[seq]) - - # Get page table - self.page_table = self.get_page_table(list(self.sequences.keys())) - - return self.sequences - - def step( - self, - layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_new_seqlens, - cu_cached_seqlens, - qkv_format: str, - ): - """ - Copy the new tokens to the paged KV cache. - - Parameters - ---------- - layer_number: int - Layer number of attention in the model - new_k: torch.Tensor - New key tokens for layer_number in current inference iteration - new_v: torch.Tensor - New value tokens for layer_number in current inference iteration - cu_new_seqlens: torch.Tensor - Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] - cu_cached_seqlens: torch.Tensor - Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] - qkv_format: str - Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - Full key tensor containing both previous and current key tokens - v_cache: torch.Tensor - Full value tensor containing both previous and current value tokens - page_table: torch.Tensor - Page table for current iteration, in shape [batch_size, max_pages_per_seq] - """ - k_cache, v_cache = self.cache[layer_number] - - batch_size = self.max_batch_size - ctx_len = 1 - if qkv_format == "bshd": - batch_size = new_k.shape[0] - ctx_len = new_k.shape[1] - if qkv_format == "sbhd": - batch_size = new_k.shape[1] - ctx_len = new_k.shape[0] - - tex.copy_to_kv_cache( - new_k, - new_v, - k_cache, - v_cache, - self.page_table, - cu_new_seqlens, - cu_cached_seqlens, - QKVFormat[qkv_format], - self.num_heads, - 
self.head_dim_k, - self.head_dim_v, - batch_size, - ctx_len, - self.max_seqlen, - self.max_pages_per_seq, - False, - ) - - page_table = self.page_table[:batch_size] - - return k_cache, v_cache, page_table From b476244abd411d2832d178794360c0f2701ea7fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Feb 2025 18:02:39 +0000 Subject: [PATCH 110/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- .../include/transformer_engine/fused_attn.h | 14 ++-- transformer_engine/pytorch/attention.py | 4 +- transformer_engine/pytorch/csrc/extensions.h | 8 +- .../pytorch/csrc/extensions/attention.cu | 73 ++++++++++--------- transformer_engine/pytorch/csrc/kv_cache.cuh | 45 ++++++------ transformer_engine/pytorch/inference.py | 31 ++++++-- 7 files changed, 93 insertions(+), 84 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 6c25a4521d..88da489387 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -319,7 +319,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): head_dim_q=config.head_dim_qk, max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, - allow_query_conversion=backend!="FusedAttention", + allow_query_conversion=backend != "FusedAttention", ) inference_params.allocate_memory(layer_number, qkv_format) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 62ad226962..3c7b3f5817 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -367,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] 
workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. 
* diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0a41895a8a..a59be84985 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -549,9 +549,7 @@ def get_attention_backend( ) if use_fused_attention and pad_between_seqs: use_fused_attention = False - logger.debug( - "Disabling FusedAttention for pad_between_seqs = True and KV caching" - ) + logger.debug("Disabling FusedAttention for pad_between_seqs = True and KV caching") if inference_params.is_paged: if use_fused_attention and cudnn_version < (9, 5, 0): logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+") diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 69dadcaf59..3f25b95356 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -70,10 +70,10 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, - int h_q, int d_q, int b, int max_seq_len); -void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o, - int d_o, int b, int max_seq_len, bool is_output_right_aligned); +void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q, + int d_q, int b, int max_seq_len); +void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned); void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k, diff --git 
a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9d44c09714..ee0412e3aa 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ #include "extensions.h" -#include "thd_utils.cuh" #include "kv_cache.cuh" +#include "thd_utils.cuh" constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -1036,36 +1036,32 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t template void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q, int d_q, int b, int max_seq_len) { - transformer_engine::fused_attn::reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_q.data_ptr()), - reinterpret_cast(q_buffer.data_ptr()), cu_new_lens.data_ptr(), - h_q, d_q, b, max_seq_len); + transformer_engine::fused_attn:: + reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_q.data_ptr()), + reinterpret_cast(q_buffer.data_ptr()), cu_new_lens.data_ptr(), + h_q, d_q, b, max_seq_len); } -void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, - int h_q, int d_q, int b, int max_seq_len) { +void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q, + int d_q, int b, int max_seq_len) { NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), "new_q and q_buffer must be of the same data type."); if (q_buffer.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, - max_seq_len); + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { using dtype = 
at::BFloat16; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, - max_seq_len); + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, - max_seq_len); + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { using dtype = at::Float8_e4m3fn; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, - max_seq_len); + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { using dtype = at::Float8_e5m2; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, - max_seq_len); + reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); } @@ -1076,16 +1072,18 @@ void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new **************************************************************************************************/ template -void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { - transformer_engine::fused_attn::reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(output.data_ptr()), - reinterpret_cast(output_buffer.data_ptr()), cu_new_lens.data_ptr(), - h_o, d_o, b, max_seq_len, is_output_right_aligned); +void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, + torch::Tensor cu_new_lens, int h_o, int d_o, int b, int max_seq_len, + bool is_output_right_aligned) { + transformer_engine::fused_attn:: + reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(output.data_ptr()), + 
reinterpret_cast(output_buffer.data_ptr()), + cu_new_lens.data_ptr(), h_o, d_o, b, max_seq_len, is_output_right_aligned); } -void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o, - int d_o, int b, int max_seq_len, bool is_output_right_aligned) { +void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, + int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(), "output and output_buffer must be of the same data type."); if (output.scalar_type() == at::ScalarType::Half) { @@ -1136,18 +1134,21 @@ void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch:: if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) { if (is_non_paged) { - transformer_engine::fused_attn::reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), - cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), h_kv, d_k, d_v, b, max_seq_len); + transformer_engine::fused_attn:: + reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + page_table.data_ptr(), cu_new_lens.data_ptr(), + cu_cached_lens.data_ptr(), h_kv, d_k, d_v, b, max_seq_len); } - transformer_engine::fused_attn::copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_k.data_ptr()), - reinterpret_cast(new_v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), - cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, b, - max_ctx_len, max_seq_len, max_pages_per_seq); + transformer_engine::fused_attn:: + copy_to_kv_cache_kernel<<<16, 256, 0, 
at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), + cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, + b, max_ctx_len, max_seq_len, max_pages_per_seq); } } diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh index 2772bae59b..295bbb207f 100644 --- a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -9,12 +9,12 @@ namespace transformer_engine { namespace fused_attn { template -__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, - int h_q, int d_q, int b, int max_seq_len) { +__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, int h_q, + int d_q, int b, int max_seq_len) { // new_q: thd; q_buffer: bshd; // cu_new_lens: [b + 1] for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = (cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]) * h_q * d_q; + int num_elts = (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]) * h_q * d_q; int new_token_offset = cu_new_lens[batch_idx] * h_q * d_q; int cache_offset = batch_idx * max_seq_len * h_q * d_q; scalar_t *new_q_token = new_q + new_token_offset; @@ -26,12 +26,13 @@ __global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_ne } template -__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens, int h_o, - int d_o, int b, int max_seq_len, bool is_output_right_aligned) { +__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens, + int h_o, int d_o, int b, int max_seq_len, + bool is_output_right_aligned) { // output: bshd; output_buffer: thd; // cu_new_lens: [b + 1] for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - 
int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int num_elts = new_len * h_o * d_o; int output_offset = batch_idx * max_seq_len * h_o * d_o; if (is_output_right_aligned) { @@ -48,8 +49,8 @@ __global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int template __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices, - int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int d_v, - int b, int max_seq_len) { + int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, + int d_v, int b, int max_seq_len) { // k_cache, v_cache: bshd // batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1] int actual_b = b; @@ -59,10 +60,9 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in } } for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { - int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; - int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; - for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; - token_idx += gridDim.x) { + int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; + for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) { int num_elts_k = h_kv * d_k; int num_elts_v = h_kv * d_v; int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; @@ -98,12 +98,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int *page_list = page_table + batch_idx * max_pages_per_seq; int new_token_offset = batch_idx * max_ctx_len; - int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; - int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + int 
cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int page_idx = page_list[(cached_len - new_len + i) / page_size]; - int token_idx = - page_idx * page_size + (cached_len - new_len + i) % page_size; + int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; for (int j = 0; j < h_kv * d_k; j++) { *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (new_token_offset + i) * h_kv * d_k + j); @@ -117,12 +116,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int *page_list = page_table + batch_idx * max_pages_per_seq; - int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; - int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; + int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int page_idx = page_list[(cached_len - new_len + i) / page_size]; - int token_idx = - page_idx * page_size + (cached_len - new_len + i) % page_size; + int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; for (int j = 0; j < h_kv * d_k; j++) { *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); } @@ -134,12 +132,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int *page_list = page_table + batch_idx * max_pages_per_seq; - int cached_len = cu_cached_lens[batch_idx+1] - cu_cached_lens[batch_idx]; - int new_len = cu_new_lens[batch_idx+1] - cu_new_lens[batch_idx]; 
+ int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int page_idx = page_list[(cached_len - new_len + i) / page_size]; - int token_idx = - page_idx * page_size + (cached_len - new_len + i) % page_size; + int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; for (int j = 0; j < h_kv * d_k; j++) { *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j); diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index ac760d685e..26ad1cbf90 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -102,6 +102,7 @@ class InferenceParams: two transposes for qkv_format = 'sbhd', and one memory buffer (q_buffer) and two conversion kernels (reshape_q and reshape_o) for qkv_format = 'thd'. """ + def __init__( self, max_batch_size: int, @@ -130,7 +131,9 @@ def __init__( _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - self.allow_query_conversion = allow_query_conversion and (_NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN) + self.allow_query_conversion = allow_query_conversion and ( + _NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN + ) if not self.is_paged: cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager @@ -269,7 +272,7 @@ def pre_step( self.batch_size = len(step_dict) self.sequences = self.cache_manager.pre_step(step_dict) - for k,v in enumerate(self.sequences): + for k, v in enumerate(self.sequences): self.sequences_pre[k] = self.sequences[k] - self.step_dict[k] actual_batch_size = len(step_dict) @@ -396,9 +399,14 @@ def step( if self.allow_query_conversion: q_buffer = 
self.q_buffer[layer_number] tex.reshape_q( - new_q, self.q_buffer[layer_number], self.cu_seqlens_q, - self.num_heads_q, self.head_dim_q, - self.max_batch_size, self.max_ctx_len) + new_q, + self.q_buffer[layer_number], + self.cu_seqlens_q, + self.num_heads_q, + self.head_dim_q, + self.max_batch_size, + self.max_ctx_len, + ) self.q_orig[layer_number] = new_q k_cache, v_cache, page_table = self.cache_manager.step( @@ -436,9 +444,16 @@ def post_step( output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() if self.input_qkv_format == "thd" and self.allow_query_conversion: output_buffer = self.q_orig[layer_number] - tex.reshape_o(output, output_buffer, self.cu_seqlens_q, - self.num_heads_q, self.head_dim_q, self.batch_size, - self.max_ctx_len, self.is_output_right_aligned) + tex.reshape_o( + output, + output_buffer, + self.cu_seqlens_q, + self.num_heads_q, + self.head_dim_q, + self.batch_size, + self.max_ctx_len, + self.is_output_right_aligned, + ) output = output_buffer.view(output_buffer.shape[0], -1) return output From 3cb001d2cdaf141a9da2c11118e39d59184cb3ce Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sun, 23 Feb 2025 20:42:26 -0800 Subject: [PATCH 111/239] WIP: some lint fixes Signed-off-by: Charlene Yang --- .../fused_attn_f16_arbitrary_seqlen.cu | 12 ++++----- .../jax/csrc/extensions/attention.cpp | 10 ++++--- transformer_engine/pytorch/attention.py | 24 +++++++---------- transformer_engine/pytorch/inference.py | 26 ++++++++++--------- 4 files changed, 35 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index c3a650f251..102d44359f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -380,7 +380,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( alignTo<16>((b + 1) * 
typeToSize(ragged_offset_type)); size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { - size_t count = 2 * ((size_t)is_ragged_q + (size_t)is_ragged_kv); + size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); if (is_ragged_q && cudnn_runtime_version >= 90600) { seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { @@ -440,13 +440,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsV = nullptr; if (is_ragged_kv) { devOffsetsK = - static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; + static_cast(devOffsets) + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { devOffsetsS = static_cast(devOffsets) + - ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( @@ -833,7 +833,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { - size_t count = 2 * ((size_t)is_ragged_q + (size_t)is_ragged_kv); + size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); if (is_ragged_q && cudnn_runtime_version >= 90600) { seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { @@ -901,13 +901,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsV = nullptr; if (is_ragged_kv) { devOffsetsK = - static_cast(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset; + static_cast(devOffsets) + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + 
num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { devOffsetsS = static_cast(devOffsets) + - ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset; + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 7447cd1871..83db39426c 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -129,6 +129,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); + auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); @@ -164,7 +165,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), + ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); @@ -173,7 +174,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, 
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { @@ -252,6 +253,7 @@ static void FusedAttnForwardImpl( backend, softmax_aux); /* Call the underlying NVTE API */ + auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); @@ -269,7 +271,7 @@ static void FusedAttnForwardImpl( nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { @@ -283,7 +285,7 @@ static void FusedAttnForwardImpl( s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), 
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else { diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index a59be84985..b3d7142d93 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -32,7 +32,6 @@ fused_attn_fwd, fused_attn_bwd, QKVLayout, - QKVFormat, AttnBiasType, AttnMaskType, FusedAttnBackend, @@ -1058,6 +1057,7 @@ def get_attention_backend( @torch.no_grad() def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): + """Convert cu_seqlens to attention_mask""" seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) @@ -5611,24 +5611,15 @@ def run_iteratively(q, k, v): check_strides_kv and check_shapes_kv and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) - and is_same_q_kv_format ): # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd # three chunks of memory, q, k and v, which may be disjoint or consecutive, and # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True - qkv_layout = "_".join(list([qkv_format]) * 3) - elif ( - check_strides_kv - and check_shapes_kv - and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) - and not is_same_q_kv_format - ): - # sbhd_bshd_bshd, bshd_sbhd_sbhd, thd_bshd_bshd, thd_sbhd_sbhd - # three chunks of memory, q, k and v, which may be disjoint or consecutive, and - # when consecutive, they may have the same data pointer, i.e. 
check_ptrs_qkv=True or - # check_ptrs_qk=True or check_ptrs_kv=True - qkv_layout = q_format + "_" + kv_format + "_" + kv_format + if is_same_q_kv_format: + qkv_layout = "_".join(list([qkv_format]) * 3) + else: + qkv_layout = q_format + "_" + kv_format + "_" + kv_format else: qkv_layout = "not_supported" @@ -5930,7 +5921,10 @@ def forward( fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] if inference_params is not None: - func = flash_attn_with_kvcache + if _flash_attn_2_2_plus: + func = flash_attn_with_kvcache + if _use_flash_attn_3: + func = flash_attn_with_kvcache_v3 fa_optional_forward_kwargs_kvcache = {} cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] fa_optional_forward_kwargs_kvcache["cache_seqlens"] = cache_seqlens diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 26ad1cbf90..a935374cd2 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -4,10 +4,10 @@ """Inference.""" import os +import logging from collections import OrderedDict, defaultdict -from typing import Optional, Dict, List +from typing import Optional, List from einops import rearrange -import logging import torch @@ -20,7 +20,7 @@ class KVCacheManager: """Base KV cache manager""" - def __init__(self, *args, **kwargs): + def __init__(self): """Initialize cache manager""" self.cache = {} self.sequences = OrderedDict() @@ -35,7 +35,7 @@ def allocate_memory(self, layer_number: int): def pre_step( self, - step_dict: OrderedDict, + step_dict: OrderedDict, # pylint: disable=unused-argument ): """Update tracked sequences and prepare for step()""" return self.sequences @@ -43,11 +43,11 @@ def pre_step( def step( self, layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_new_seqlens: torch.Tensor, - cu_cached_seqlens: torch.Tensor, - qkv_format: str, + new_k: torch.Tensor, # pylint: disable=unused-argument + new_v: torch.Tensor, # 
pylint: disable=unused-argument + cu_new_seqlens: torch.Tensor, # pylint: disable=unused-argument + cu_cached_seqlens: torch.Tensor, # pylint: disable=unused-argument + qkv_format: str, # pylint: disable=unused-argument ): """Copy the new tokens to KV cache""" return *self.cache[layer_number], None @@ -206,8 +206,8 @@ def reset(self): self.sequences = OrderedDict() self.cache_manager.reset() if self.input_qkv_format == "thd" and self.allow_query_conversion: - for layer_number in self.q_buffer: - self.q_buffer[layer_number].fill_(0) + for _, q_buffer in self.q_buffer.items(): + q_buffer.fill_(0) def __repr__(self) -> str: if self.is_paged: @@ -273,7 +273,7 @@ def pre_step( self.sequences = self.cache_manager.pre_step(step_dict) for k, v in enumerate(self.sequences): - self.sequences_pre[k] = self.sequences[k] - self.step_dict[k] + self.sequences_pre[k] = v - self.step_dict[k] actual_batch_size = len(step_dict) seqlens_q = list(step_dict.values()) @@ -471,6 +471,7 @@ def __init__( dtype: torch.dtype, head_dim_v: Optional[int] = None, ): + super().__init__() """Initialize cache manager""" self.max_batch_size = max_batch_size self.max_seqlen = max_seqlen @@ -654,6 +655,7 @@ def __init__( max_seqlen: int, head_dim_v: Optional[int] = None, ): + super().__init__() """Initialize cache manager""" self.total_num_pages = total_num_pages self.page_size = page_size From 583b76f639ea5941458c07b89ac57e2870cd53d8 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sun, 23 Feb 2025 22:15:15 -0800 Subject: [PATCH 112/239] WIP: add docstring for IP Signed-off-by: Charlene Yang --- transformer_engine/pytorch/attention.py | 10 +++--- transformer_engine/pytorch/inference.py | 48 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b3d7142d93..fcf6ac45ef 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -514,11 
+514,11 @@ def get_attention_backend( use_unfused_attention = False # Filter: KV cache - # backend | non-paged/paged | precision - # --------------------------------------------------------------------------------- - # FlashAttention | non-paged/paged | FP16/BF16 - # FusedAttention | non-paged/paged | FP16/BF16 (non-paged/paged), FP8 (non-paged) - # UnfusedDotProductAttention | non-paged/paged | FP32/FP16/BF16 + # backend | precision + # ------------------------------------------------------------------------- + # FlashAttention | FP16/BF16 (non-paged/paged) + # FusedAttention | FP16/BF16 (non-paged/paged), FP8 (non-paged) + # UnfusedDotProductAttention | FP32/FP16/BF16 (non-paged/paged) if inference_params is not None: if context_parallel: logger.debug( diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index a935374cd2..43ac09dbf6 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -59,6 +59,54 @@ class InferenceParams: to efficiently cache previous tokens and reuse them for the current inference iteration. + A typical KV caching workflow is as follows.:: + + modules = [TransformerLayer() for _ in range(num_layers)] + model = torch.nn.Sequential(*modules) + inference_params = InferenceParams(max_batch_size, max_seqlen_kv, ...) 
+ for i in range(inference_iterations): + # seq_ids = [0, 2, 3] + # step_lens = [10, 1, 1] + # step_dict = OrderedDict(zip(seq_ids, step_lens)) + inference_params.pre_step(step_dict) + output = model( + ..., + inference_params=inference_params, + attn_mask_type="padding_causal", + ) + # assume qkv_format = "bshd" + if inference_params.is_output_right_aligned: + output = output[:,-1] + else: + output = output[:,step_dict.values()] + + + The memory allocation and copies of the new KV tokens to KV cache take place + in the following locations.:: + + class TransformerLayer: + class MultiHeadAttention: + if self.layer_number not in inference_params: + inference_params.allocate_memory(self.layer_number) + class DotProductAttention: + if inference_params is not None: + q, k_cache, v_cache, qkv_format = inference_params.step( + new_q, new_k, new_v, qkv_format) + output = attention(q, k_cache, v_cache, new_qkv_format) + if inference_params is not None: + output = inference_params.post_step(output) + return output + + InferenceParams supports cache_qkv_format = "bshd" only, and the step() function may + change qkv_format depending on the attention backend. 
+ + Backend | Before step() | After step() + ------------------------------------------------------------------------------------ + FusedAttention | {bshd, sbhd, thd} | {bshd_2bshd, sbhd_2bshd, thd_2bshd} + FlashAttention | {bshd, sbhd, thd} | {bshd, sbhd, thd} + UnfusedDotProductAttention | {bshd, sbhd, thd} | {bshd, sbhd, bshd} + + Parameters ---------- max_batch_size: int From d668f18f4a2b93c92d9d63f8f6d14ab3f075ec0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 24 Feb 2025 14:50:49 +0100 Subject: [PATCH 113/239] [Pytorch] Added missing assert_dim_for_fp8_exec for Linear * fix Signed-off-by: Pawel Gadzinski * reshape inp Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e51513630f..bae21eebfd 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -27,6 +27,7 @@ divide, init_method_constant, non_tn_fp8_gemm_supported, + assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, requires_grad, @@ -118,13 +119,14 @@ def forward( # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication nvtx_range_push(f"{nvtx_label}.input_cast_comm") - inputmat = inp + inputmat = inp.view(-1, in_features) inputmat_total = None with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) own_quantized_input = False if fp8: + assert_dim_for_fp8_exec(inputmat, weight) if ( any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not FP8GlobalStateManager.get_fp8_recipe().delayed() From 229dd04537abebfea998d56fe518a9a9bf70b483 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: 
Mon, 24 Feb 2025 14:57:51 +0100 Subject: [PATCH 114/239] [PyTorch] Run all Python tests, even if one of them fails * non-exit tests Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 34 ++++++++++++---------- qa/L1_pytorch_distributed_unittest/test.sh | 17 ++++++----- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index dd7f95bce0..6915d618f0 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -2,22 +2,26 @@ # # See LICENSE for license information. -set -e : ${TE_PATH:=/opt/transformerengine} pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py -pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py -pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py -NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py -pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py -pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py -pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py -pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py -pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py -pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s 
$TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + +FAIL=0 + +pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1 +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 +NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 + +exit $FAIL diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 8ee0be1af5..5e3823d85c 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -2,14 +2,17 @@ # # See LICENSE for license information. 
-set -e - : ${TE_PATH:=/opt/transformerengine} pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py + +FAIL=0 + +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1 # pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1 + +exit $FAIL From f13b86173a674ca6fd6035902455c2f3c5675784 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Mon, 24 Feb 2025 16:55:44 -0800 Subject: [PATCH 115/239] fix sequences_pre Signed-off-by: Charlene Yang --- transformer_engine/pytorch/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 43ac09dbf6..67dec8c376 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -320,7 +320,8 @@ def pre_step( self.batch_size = len(step_dict) self.sequences = self.cache_manager.pre_step(step_dict) - for k, v in enumerate(self.sequences): + self.sequences_pre = OrderedDict() + for k, v in self.sequences.items(): self.sequences_pre[k] = v - self.step_dict[k] actual_batch_size = len(step_dict) From a06d72c247aab3759727ce33e885c2dbd2d4f582 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 
Feb 2025 00:58:14 +0000 Subject: [PATCH 116/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- qa/L0_pytorch_unittest/test.sh | 2 +- .../fused_attn_f16_arbitrary_seqlen.cu | 14 ++++---- .../jax/csrc/extensions/attention.cpp | 32 ++++++++++--------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 438ab3d8fd..0c7a907051 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -25,4 +25,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || FAIL=1 -exit $FAIL \ No newline at end of file +exit $FAIL diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 102d44359f..e12122f822 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -439,14 +439,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsK = nullptr; void *devOffsetsV = nullptr; if (is_ragged_kv) { - devOffsetsK = - static_cast(devOffsets) + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; + devOffsetsK = static_cast(devOffsets) + + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { devOffsetsS = static_cast(devOffsets) + - (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; + 
(static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * + num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( @@ -900,14 +901,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsK = nullptr; void *devOffsetsV = nullptr; if (is_ragged_kv) { - devOffsetsK = - static_cast(devOffsets) + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; + devOffsetsK = static_cast(devOffsets) + + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; if (is_ragged_q && cudnn_runtime_version >= 90600) { devOffsetsS = static_cast(devOffsets) + - (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * + num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 83db39426c..4b64a113ab 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -165,16 +165,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), - nullptr); + ragged_offset_tensor.data(), 
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, + kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { @@ -271,9 +272,10 @@ static void FusedAttnForwardImpl( nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, 
workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -281,13 +283,13 @@ static void FusedAttnForwardImpl( auto q_tensor = TensorWrapper(q, q_shape, dtype); auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } From 87441885351bd3ae9f485bf81819d40166ccb043 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 26 Feb 2025 03:09:21 +0800 Subject: [PATCH 117/239] Minor fixes for attention (#1504) * minor fixes for attention Signed-off-by: Charlene Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --------- Signed-off-by: Charlene Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 6 +++--- transformer_engine/pytorch/attention.py | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 01151a50db..13c99ae244 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -153,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging // special conditions for blackwell // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 - !(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) && + !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) && // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && @@ -238,12 +238,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && // TODO(cyang): fix bug for BRCM + cross-attention on sm100 - (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) 
|| cudnn_runtime_version > 90700))))) && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d6b9894fc3..7666d3f32b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -118,7 +118,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_version = PkgVersion("0") _flash_attn_version_required = PkgVersion("2.1.1") _flash_attn_version_required_blackwell = PkgVersion("2.7.3") -_flash_attn_max_version = PkgVersion("2.7.3") +_flash_attn_max_version = PkgVersion("2.7.4.post1") _flash_attn_2_plus = False _flash_attn_2_1_plus = False _flash_attn_2_3_plus = False @@ -507,13 +507,16 @@ def get_attention_backend( if use_flash_attention and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 - or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0))) + or ( + head_dim_qk > 192 + and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) + ) ): if _flash_attn_is_installed: logger.debug( "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " - "head_dim_qk <= 256 (>192 requires sm80/90). " + "head_dim_qk <= 256 (>192 requires sm80/90/100+). 
" "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", head_dim_qk, head_dim_v, From 9351a179b47fba83472fc38504f0793de5850a6d Mon Sep 17 00:00:00 2001 From: guyueh1 <140554423+guyueh1@users.noreply.github.com> Date: Tue, 25 Feb 2025 14:09:40 -0800 Subject: [PATCH 118/239] Fix a crash in NeMo 2.0 during module._apply(lambda t: t.cpu()) (#1502) * Fix a crash with module._apply(lambda t: t.cpu()) Signed-off-by: Guyue Huang * Add comments Signed-off-by: Guyue Huang * Make sure tensor is moved to dst device before quantizer quantizes Signed-off-by: Guyue Huang --------- Signed-off-by: Guyue Huang Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/tensor/float8_tensor.py | 2 ++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index da788182a0..989959817a 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -484,6 +484,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) # Just copy FP8 data if other tensor is Float8Tensor if isinstance(tensor, Float8Tensor): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 86b13415a1..6e3835fbef 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -368,6 +368,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) # Just copy FP8 data if other tensor is MXFP8Tensor if isinstance(tensor, 
MXFP8Tensor): From 94c929192200b729089d1feda2d0cd6b6c65d621 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Tue, 25 Feb 2025 14:31:56 -0800 Subject: [PATCH 119/239] Adding remove_caches API to Float8Tensor class (#1425) * add remove_caches api Signed-off-by: Youngeun Kwon * Update transformer_engine/pytorch/tensor/float8_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Youngeun Kwon * explicit delete Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Youngeun Kwon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/tensor/float8_tensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 989959817a..49bf4facfa 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -334,6 +334,14 @@ def _reset_caches(self) -> None: """ self._transpose_invalid = True + def remove_caches(self) -> None: + """ + Remove transpose cache and mark it as invalid. + """ + self._transpose_invalid = True + del self._transpose # explicitly deletes the data for safety + self._transpose = None + def clear(self): """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" self._data = torch.Tensor() if self._data is not None else None From 8ca2caf8a40e9eea6ddb84a16e2dd6f7aa9bdac6 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Tue, 25 Feb 2025 17:32:47 -0800 Subject: [PATCH 120/239] Parallel Cross Entropy using online softmax (#1456) * Added parallel cross entropy loss implementation using online softmax Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added tests Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added reshape of loss output Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added to test list Signed-off-by: Selvaraj Anandaraj * Added Triton dependency Signed-off-by: Selvaraj Anandaraj * Added copyright Signed-off-by: Selvaraj Anandaraj * Fixed lint errors Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Selvaraj Anandaraj * Fixed lint and triton failure Signed-off-by: Selvaraj Anandaraj * Removed flattening for scalars Signed-off-by: Selvaraj Anandaraj * Skip tests on Blackwell due to TE CI caveat Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added reason arg Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Do not register Triton dependency with setuptools Signed-off-by: Tim Moon --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Signed-off-by: Tim Moon Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- docs/api/pytorch.rst | 2 + qa/L0_pytorch_unittest/test.sh | 1 + setup.py | 2 + tests/pytorch/test_parallel_cross_entropy.py | 108 ++++++ transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/cross_entropy.py | 72 ++++ .../pytorch/triton/cross_entropy.py | 341 ++++++++++++++++++ 7 files changed, 527 insertions(+) create mode 100644 tests/pytorch/test_parallel_cross_entropy.py create mode 100644 transformer_engine/pytorch/cross_entropy.py create mode 100644 transformer_engine/pytorch/triton/cross_entropy.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 4154a18598..67a123d334 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -54,6 +54,8 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index +.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy + .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs .. 
autoapifunction:: transformer_engine.pytorch.initialize_ub diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 6915d618f0..870e869795 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -22,6 +22,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 exit $FAIL diff --git a/setup.py b/setup.py index 1d9818458e..856c518f79 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,8 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch"]) + # Blackwell is not supported as of Triton 3.2.0, need custom internal build + # install_reqs.append("triton") test_reqs.extend(["numpy", "torchvision", "prettytable"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py new file mode 100644 index 0000000000..5e355dc989 --- /dev/null +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import random +import pytest +import torch +from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy + + +class TestParallelCrossEntropy: + + def generate_iters(self, iters: int): + self.iters = iters + + def generate_infra(self, reduce_loss: bool, label_smoothing: float): + self.test_loss_func = parallel_cross_entropy + self.ref_loss_func = torch.nn.CrossEntropyLoss( + label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" + ) + + def generate_input(self, dtype: torch.dtype, swap_dim: bool): + + SQ = random.choice([64, 128]) + batch = random.choice([1, 2]) + vocab = random.choice([64000, 128000]) + + if swap_dim: + self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() + self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() + else: + self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() + self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() + + self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) + self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) + + def one_iteration_test( + self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool + ): + + self.generate_input(dtype, swap_dim) + + self.input_test.requires_grad_(True) + self.input_ref.requires_grad_(True) + + test_loss = self.test_loss_func( + self.input_test, self.tar_test, label_smoothing, reduce_loss, None + ) + if reduce_loss: + test_loss.backward() + + ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) + if reduce_loss: + ref_loss.backward() + + test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss + + torch.testing.assert_close(test_loss, ref_loss, check_dtype=False) + if reduce_loss: + torch.testing.assert_close( + torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad + ) + + self.input_test = None + self.input_ref = None + self.tar_test = None + self.tar_ref = None + + def 
test_float32_input(self): + self.generate_iters(5) + self.generate_infra(True, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=True + ) + + def test_bfloat16_input(self): + self.generate_iters(5) + self.generate_infra(True, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.bfloat16, swap_dim=False, label_smoothing=0, reduce_loss=True + ) + + def test_swapped_input(self): + self.generate_iters(5) + self.generate_infra(True, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=True, label_smoothing=0, reduce_loss=True + ) + + def test_label_smoothing(self): + self.generate_iters(3) + self.generate_infra(True, 0.1) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=False, label_smoothing=0.1, reduce_loss=True + ) + + def test_non_reduced_loss(self): + self.generate_iters(1) + self.generate_infra(False, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False + ) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index d424b97f74..92250cd322 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -89,6 +89,7 @@ def _load_library(): from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers +from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py new file mode 100644 index 0000000000..e5da32164d --- /dev/null +++ b/transformer_engine/pytorch/cross_entropy.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# See LICENSE for license information. + +"""Cross Entropy Loss API""" + +import torch + +import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy + +__all__ = [ + "parallel_cross_entropy", +] + + +class CrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16/FP32, the + loss and gradient calculations happen in FP32 only. The returned loss is always in FP32, the input gradients are cast + to the datatype of the input. + """ + + @staticmethod + def forward( + ctx, _input, target, label_smoothing=0.0, reduce_loss=False, dist_process_group=None + ): + """ + The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each + distributed rank should be (*,V/world_size). Note that each of the ranks should get equal shards along the V dimension. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. + target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1]. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. + dist_process_group (torch.distributed.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device. + + Returns: + tensor: The computed loss. + """ + loss, _input = triton_cross_entropy.cross_entropy_forward( + _input, target, label_smoothing, reduce_loss, dist_process_group + ) + + ctx.save_for_backward(_input.detach()) + return loss + + @staticmethod + def backward(ctx, grad_output): + """ + The backward pass of the Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors.
+ grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + (_input,) = ctx.saved_tensors + _input = triton_cross_entropy.cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + ) + + +parallel_cross_entropy = CrossEntropyFunction.apply diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py new file mode 100644 index 0000000000..43a3100926 --- /dev/null +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -0,0 +1,341 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient Cross Entropy kernels written with OpenAI Triton.""" + +from typing import Union +from functools import reduce +from operator import mul + +import torch +import torch.distributed as dist + +import triton +import triton.language as tl + + +@triton.jit +def online_softmax_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes the m/d components on this TP rank for the online softmax. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride (int): The stride of the m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) + else: + X_y = float("-inf") + else: + X_y = float("-inf") + + m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. use the notation from the paper + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( + tl.float32 + ) + block_max = tl.max(X_block) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + tl.store(m_d_X_y_ptr, m) + tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) + tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) + + +@triton.jit +def cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + loss_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + world_size, + n_cols, + n_non_ignore, + label_smoothing: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + loss_stride (int): The stride of the loss tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride: The stride of m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + world_size (int): The size of world involved in this distributed loss calculation. 
+ n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + loss_ptr += program_id * loss_stride + m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride + + # Need to reduce the m/d/X_y values from other TP ranks + m = tl.load(m_d_X_y_ptr) + d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) + ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) + + for i in range(1, world_size): + offset = i * 3 * n_non_ignore * m_d_X_y_stride + access_ptr = m_d_X_y_ptr + offset + m_new = tl.load(access_ptr) + d_new = tl.load(access_ptr + m_d_X_y_stride) + X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) + + d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) + m = tl.maximum(m, m_new) + ori_X_y = tl.maximum(ori_X_y, X_y_new) + + # Label smoothing is a general case of normal cross entropy + scaled_x_sum = 0.0 + eps = label_smoothing / (n_cols * world_size) + + # 4. 
[Online softmax] second pass: calculate the gradients + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # N is the number of non ignored elements in the batch + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols * world_size, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + grad_dtype = X_block.dtype + X_block = X_block.to(tl.float32) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = -log (softmax(X_y)) = -log (e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = -((X_y - max(X)) - log(sum(e ^ (X - max(X))))) + loss = -(ori_X_y - m - tl.log(d)) + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) - eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + loss = loss * (1 - label_smoothing) + smooth_loss + + # 6.
Specially handle the i==y case where `dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N` + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx) + X_y += -(1 - label_smoothing) / (n_non_ignore) + tl.store(X_ptr + y - vocab_start_idx, X_y) + + tl.store(loss_ptr, loss) + + +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed to by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations.
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) + + +def cross_entropy_forward( + _input: torch.Tensor, + target: torch.Tensor, + label_smoothing: float, + reduce_loss: bool, + dist_process_group: Union[dist.ProcessGroup, None], +): + """Forward implementation of Cross Entropy kernel""" + + B, SQ, V = _input.shape + n_rows = B * SQ + + assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID." + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device) + + # tensor to hold this rank's m/d/X_y values + m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device) + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + + online_softmax_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + m_d_X_y_ptr=m_d_X_y, + m_d_X_y_stride=m_d_X_y.stride(-1), + rank=rank, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + if world_size > 1: + m_d_X_y_gathered = torch.zeros( + n_rows * 3 * world_size, dtype=torch.float32, device=_input.device + ) + 
dist.all_gather_into_tensor(m_d_X_y_gathered, m_d_X_y, group=dist_process_group) + else: + m_d_X_y_gathered = m_d_X_y + + cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), + loss_ptr=loss_1d, + loss_stride=loss_1d.stride(-1), + m_d_X_y_ptr=m_d_X_y_gathered, + m_d_X_y_stride=m_d_X_y_gathered.stride(-1), + rank=rank, + world_size=world_size, + n_cols=V, + n_non_ignore=n_rows, + label_smoothing=label_smoothing, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows) + + return loss, _input + + +def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): + """Backward implementation of cross entropy loss kernel""" + + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + + else: + B, SQ, V = _input.shape + n_rows = B * SQ + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + return _input From 5d85857a018e53976137e2cc94b9788a9c320b52 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Wed, 26 Feb 2025 02:36:08 +0100 Subject: [PATCH 121/239] Added memory alignment check to cast_fp8_1D (#1507) * Added TMA alignment check to cast_fp8_1D Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use tensor const-ref instead of tensor const-ptr Signed-off-by: Tim Moon --------- Signed-off-by: Oleg Goncharov Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- 
transformer_engine/common/common.cu | 8 +------- transformer_engine/common/common.h | 14 ++++++++++++-- transformer_engine/common/util/cast_kernels.cuh | 9 +++++++-- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index cbeec66958..c3a556edba 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -67,11 +67,6 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { return dtypeMapping.at(dtype); } -inline bool isPointerAligned(const void *const ptr, const int alignment) { - const uint64_t ptr_as_uint = reinterpret_cast(ptr); - return ptr_as_uint % alignment == 0; -} - // Set up parameters to create TMA descriptor. void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, @@ -100,8 +95,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); - constexpr int TMA_gmem_alignment = 16; // Alignment of the global memory address - NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment), + NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), "Tensor data pointer must be 16B aligned"); const int TMA_needed_size = TMA_gmem_alignment / type_size; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ca9103532d..46eb248156 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -426,6 +427,17 @@ constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +// Alignment requirements for the Tensor Memory Accelerator (TMA) +constexpr int TMA_gmem_alignment = 16; // global memory address 
alignment + +inline bool is_aligned_ptr(const void *ptr, size_t alignment) { + return reinterpret_cast(ptr) % alignment == 0; +} + +inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { + return is_aligned_ptr(static_cast(t.data.dptr), alignment); +} + size_t typeToSize(const DType type); void CheckNoopTensor(const Tensor &t, const std::string &name); @@ -465,8 +477,6 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); -inline bool isPointerAligned(const void *const ptr, const int alignment); - // Set up parameters to create TMA descriptor. void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 404babc745..d1ede8d98d 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1110,7 +1110,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype())) { + if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_gmem_alignment) && + is_aligned_tensor_data(*output, TMA_gmem_alignment)) { // Aligned AND FP8 cast_fp8_1D(input, output, stream); } else { @@ -1118,7 +1120,10 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons CastVectorizedUnaryKernelLauncher(input, noop, output, stream); } } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype())) { + if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_gmem_alignment) && + is_aligned_tensor_data(*output, 
TMA_gmem_alignment) && + is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { // Aligned AND FP8 (+dAct) cast_fp8_2D(input, act_input, output, dbias, workspace, stream); From 2834e4ad83b86fff086b2882737e2c55f99d994f Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 25 Feb 2025 18:33:46 -0800 Subject: [PATCH 122/239] [PyTorch] Skip context parallelism tests if not enough GPUs (#1508) * Skip context parallelism tests if not enough GPUs Signed-off-by: Tim Moon * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../fused_attn/test_fused_attn_with_cp.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 9866591e8d..85950347ba 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -3,9 +3,10 @@ # See LICENSE for license information. 
import os -import pytest import subprocess -from test_fused_attn import ModelConfig + +import pytest +import torch from transformer_engine.pytorch.attention import ( _flash_attn_2_plus, _flash_attn_2_3_plus, @@ -15,6 +16,8 @@ get_cudnn_version, ) +from test_fused_attn import ModelConfig + model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA @@ -58,6 +61,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 + if num_gpus > torch.cuda.device_count(): + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + config = model_configs_flash_attn[model] if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") @@ -77,7 +84,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( - num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2, + num_gpus_per_node=num_gpus, dtype=dtype, model=model, qkv_format=qkv_format, @@ -115,6 +122,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) @pytest.mark.parametrize("fp8_mha", [False, True]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 + if num_gpus > torch.cuda.device_count(): + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only 
supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): @@ -155,7 +166,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha subprocess.run( get_bash_arguments( - num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2, + num_gpus_per_node=num_gpus, dtype=dtype, model=model, qkv_format=qkv_format, From 9b33071006780ad0caed15478e89989b97a7ceb0 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Wed, 26 Feb 2025 21:59:04 -0800 Subject: [PATCH 123/239] WIP: minor fixes for multi-layer Signed-off-by: Charlene Yang --- transformer_engine/pytorch/attention.py | 2 +- .../pytorch/csrc/extensions/attention.cu | 2 +- transformer_engine/pytorch/csrc/kv_cache.cuh | 19 ++++------ transformer_engine/pytorch/graph.py | 35 ++++++++----------- 4 files changed, 23 insertions(+), 35 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fcf6ac45ef..933eaf1eb8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -8411,7 +8411,7 @@ def forward( # Pre-allocate memory for key-values for inference # ================================================= - if inference_params is not None and self.layer_number not in inference_params.layer_numbers: + if inference_params is not None and self.layer_number not in inference_params.cache_manager.cache: inference_params.allocate_memory(self.layer_number, self.qkv_format) # ====================== diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index ee0412e3aa..ab34e35d22 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1148,7 +1148,7 @@ void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch:: reinterpret_cast(k_cache.data_ptr()), reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), 
cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, - b, max_ctx_len, max_seq_len, max_pages_per_seq); + b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); } } diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh index 295bbb207f..6209db18e4 100644 --- a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -77,11 +77,6 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in } } } - if (blockIdx.x == 0) { - for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { - batch_indices[batch_idx] = batch_idx; - } - } } template @@ -89,19 +84,19 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar scalar_t *v_cache, int *page_table, int *cu_new_lens, int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, - int max_pages_per_seq) { + int max_pages_per_seq, bool is_non_paged) { // new_k, new_v: qkv_format; k_cache, v_cache: bshd // cu_new_lens, cu_cached_lens: [b + 1] // page_table: [b, max_pages_per_seq] int page_size = max_seq_len / max_pages_per_seq; if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int *page_list = page_table + batch_idx * max_pages_per_seq; + int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int new_token_offset = batch_idx * max_ctx_len; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; for (int i = threadIdx.x; i < new_len; i += blockDim.x) { - int page_idx = page_list[(cached_len - new_len + i) / page_size]; + int page_idx = is_non_paged ? 
batch_idx : page_list[(cached_len - new_len + i) / page_size]; int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; for (int j = 0; j < h_kv * d_k; j++) { *(k_cache + token_idx * h_kv * d_k + j) = @@ -115,11 +110,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar } } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int *page_list = page_table + batch_idx * max_pages_per_seq; + int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; for (int i = threadIdx.x; i < new_len; i += blockDim.x) { - int page_idx = page_list[(cached_len - new_len + i) / page_size]; + int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; for (int j = 0; j < h_kv * d_k; j++) { *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); @@ -131,11 +126,11 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar } } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int *page_list = page_table + batch_idx * max_pages_per_seq; + int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; for (int i = threadIdx.x; i < new_len; i += blockDim.x) { - int page_idx = page_list[(cached_len - new_len + i) / page_size]; + int page_idx = is_non_paged ? 
batch_idx : page_list[(cached_len - new_len + i) / page_size]; int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; for (int j = 0; j < h_kv * d_k; j++) { *(k_cache + token_idx * h_kv * d_k + j) = diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 05fa4b8010..42b83dfb53 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -258,11 +258,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if callables[0].training: grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), - inputs=tuple( - i - for i in static_input_surface - if isinstance(i, torch.Tensor) and i.requires_grad - ), + inputs=tuple(i for i in static_input_surface if i.requires_grad), grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), only_inputs=True, allow_unused=allow_unused_input, @@ -321,22 +317,23 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) - with torch.cuda.graph(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input, - retain_graph=retain_graph_in_backward, - ) + if callables[0].training: + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, + ) # Constructs a tuple suitable for returning from 
Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if arg.requires_grad: + if callables[0].training and isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: @@ -377,11 +374,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument with torch.cuda.graph(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple( - i - for i in static_input_surface - if isinstance(i, torch.Tensor) and i.requires_grad - ), + inputs=tuple(i for i in static_input_surface if i.requires_grad), grad_outputs=tuple(o for o in static_grad_outputs if o is not None), only_inputs=True, allow_unused=allow_unused_input, @@ -393,7 +386,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if isinstance(arg, torch.Tensor) and arg.requires_grad: + if callables[0].training and isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: From e3de9bcaaf30ca22a430f772f71ecfabe7aa01f5 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 27 Feb 2025 09:29:33 -0800 Subject: [PATCH 124/239] WIP: initial multi-layer test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 434 ++++++++++++++------ 1 file changed, 313 insertions(+), 121 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 88da489387..f67a25fb9c 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ 
-13,11 +13,20 @@ from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables +from transformer_engine.pytorch.transformer import ( + TransformerLayer, +) from transformer_engine.pytorch.attention import ( + MultiheadAttention, DotProductAttention, InferenceParams, ) -from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.utils import ( + get_device_compute_capability, + init_method_normal, + scaled_init_method_normal, + is_bf16_compatible, +) from test_fused_attn import ( ModelConfig, reset_rng_states, @@ -40,7 +49,7 @@ model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias "infer_0": ModelConfig( - 4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 + 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 ), # "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), } @@ -81,7 +90,7 @@ def __init__( self.context_lens = torch.randint( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) - # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + #self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -202,15 +211,69 @@ def step(self, dynamic_fill: bool = True): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FlashAttention"]) #, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("module", ["TransformerLayer"])#, "MultiHeadAttention", "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) -def 
test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): +def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): reset_rng_states() + #_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + #_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + #def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + # """Get cuda rng tracker.""" + # return _DUMMY_CUDA_RNG_STATE_TRACKER + + num_layers = 1 #2 + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, num_layers) logger = logging.getLogger("test_paged_attn") config = model_configs_infer[model] - num_layers = 2 - layer_number = 1 + hidden_size = config.head_dim_qk * config.num_heads + if module == "TransformerLayer": + model = [TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4*hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type="causal", + params_dtype=dtype, + #get_rng_state_tracker=get_dummy_cuda_rng_tracker, + attn_input_format="bshd", + ).cuda().eval() for layer_number in range(1, num_layers+1)] + if module == "MultiHeadAttention": + model = [MultiHeadAttention( + hidden_size=hidden_size, + num_attention_heads=config.num_heads, + kv_channels=config.head_dim_qk, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + attn_mask_type="causal", + num_gqa_groups=config.num_gqa_groups, + #get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + qkv_format="bshd", + ).cuda().eval() for layer_number in range(1, num_layers+1)] + if module == "DotProductAttention": + model = [DotProductAttention( + 
kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format="bshd", + attn_mask_type="causal", + #get_rng_state_tracker=get_dummy_cuda_rng_tracker, + ).cuda().eval() for layer_number in range(1, num_layers+1)] + #model = torch.nn.Sequential(*model).cuda().eval() # figure out supported backends inference_params_qkv_format = "bshd" @@ -242,17 +305,17 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): config.max_seqlen_kv = 256 # create model - model = ( - DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - ) - .cuda() - .eval() - ) + #model = ( + # DotProductAttention( + # kv_channels=config.head_dim_qk, + # num_attention_heads=config.num_heads, + # num_gqa_groups=config.num_gqa_groups, + # layer_number=layer_number, + # attention_dropout=config.dropout_p, + # ) + # .cuda() + # .eval() + #) # generate data for all requests assert ( @@ -276,13 +339,30 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): # generate reference results logger.info("=== Generating all tokens at once ===") - full_output = model( - query_layer=q, - key_layer=k, - value_layer=v, - qkv_format="bshd", - attn_mask_type="causal", + #full_output = model( + # query_layer=q, + # key_layer=k, + # value_layer=v, + # qkv_format="bshd", + # attn_mask_type="causal", + #) + hidden_states = torch.randn( + (config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk), + dtype=dtype, + device="cuda", ) + #rotary_freqs = torch.randn((config.max_seqlen_kv, 1, 1, config.num_heads), dtype=torch.float, device="cuda") + h = hidden_states + for m in model: + h = m( + h, + self_attn_mask_type="causal", + #inference_params=inference_params, + 
#rotary_pos_emb=rotary_freqs, + ) + #print('full h', h[0,0,:4]) + #print('full h', h[1,6,:4]) + full_output = h # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -321,8 +401,33 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph): qkv_format=qkv_format, allow_query_conversion=backend != "FusedAttention", ) - inference_params.allocate_memory(layer_number, qkv_format) + #print('inference_params.cache_manager', inference_params.cache_manager) + for layer_number in range(1, num_layers+1): + inference_params.allocate_memory(layer_number, qkv_format) + reset_rng_states() + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, num_layers) + attn_mask_type = "padding_causal" if backend == "FlashAttention" else "padding" + model = [TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4*hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type=attn_mask_type, #"padding", #_causal", + #enc_dec_attn_mask_type="padding", #_causal", + params_dtype=dtype, + #get_rng_state_tracker=get_dummy_cuda_rng_tracker, + attn_input_format=qkv_format, + ).cuda().eval() for layer_number in range(1, num_layers+1)] + #model = torch.nn.Sequential(*model).cuda().eval() # graph the model if necessary if is_cuda_graph: t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") @@ -341,12 +446,12 @@ def gen_data(): return [ torch.ones( *shape, - config.num_heads, - config.head_dim_qk, + config.num_heads * config.head_dim_qk, device="cuda", dtype=dtype, ) - for _ in range(3) + #for _ in range(3) + for _ in range(1) ] sample_kwargs = {} @@ -365,13 +470,15 @@ def gen_data(): 
dtype=torch.int32, ) sample_kwargs["inference_params"] = inference_params - sample_kwargs["attn_mask_type"] = "padding" # _causal" + #sample_kwargs["attn_mask_type"] = "padding" # _causal" + #sample_kwargs["self_attn_mask_type"] = "padding" # _causal" sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - sample_kwargs["qkv_format"] = qkv_format + #sample_kwargs["qkv_format"] = qkv_format + #sample_kwargs["attn_input_format"] = qkv_format model = make_graphed_callables( - model, + model[0], gen_data(), num_warmup_iters=10, fp8_enabled=False, @@ -381,6 +488,7 @@ def gen_data(): sim.reset() inference_params.reset() step_dict = OrderedDict() + print('++++++++++++++== graphed ++++++++++') # simulate step by step # t-1: ... @@ -414,104 +522,140 @@ def gen_data(): batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() if qkv_format == "thd": - incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") - incremental_k = torch.Tensor().to(dtype=dtype, device="cuda") - incremental_v = torch.Tensor().to(dtype=dtype, device="cuda") + incremental_hidden_states = torch.Tensor().to(dtype=dtype, device="cuda") for i, seq in enumerate(sim.t_seq_ids): start = (sim.t_total_lens[i] - sim.step_lens[i]).item() end = sim.t_total_lens[i].item() - incremental_q = torch.cat([incremental_q, q[seq, start:end, :, :]], dim=0) - incremental_k = torch.cat( - [ - incremental_k, - k[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_qk), - ], - dim=0, - ) - incremental_v = torch.cat( - [ - incremental_v, - v[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_v), - ], - dim=0, - ) + incremental_hidden_states = torch.cat([incremental_hidden_states, hidden_states[seq, start:end, :]], dim=0) if is_cuda_graph: - incremental_q = torch.cat( - [ - incremental_q, - torch.zeros( - [max_tokens - sum(sim.step_lens), 
config.num_heads, config.head_dim_qk], - dtype=dtype, - device=incremental_q.device, - ), - ], - dim=0, - ) - incremental_k = torch.cat( + incremental_hidden_states = torch.cat( [ - incremental_k, + incremental_hidden_states, torch.zeros( - [ - max_tokens - sum(sim.step_lens), - config.num_gqa_groups, - config.head_dim_v, - ], + [max_tokens - sum(sim.step_lens), config.num_heads * config.head_dim_qk], dtype=dtype, - device=incremental_k.device, - ), - ], - dim=0, - ) - incremental_v = torch.cat( - [ - incremental_v, - torch.zeros( - [ - max_tokens - sum(sim.step_lens), - config.num_gqa_groups, - config.head_dim_v, - ], - dtype=dtype, - device=incremental_v.device, + device=incremental_hidden_states.device, ), ], dim=0, ) + #incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") + #incremental_k = torch.Tensor().to(dtype=dtype, device="cuda") + #incremental_v = torch.Tensor().to(dtype=dtype, device="cuda") + #for i, seq in enumerate(sim.t_seq_ids): + # start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + # end = sim.t_total_lens[i].item() + # incremental_q = torch.cat([incremental_q, q[seq, start:end, :, :]], dim=0) + # incremental_k = torch.cat( + # [ + # incremental_k, + # k[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_qk), + # ], + # dim=0, + # ) + # incremental_v = torch.cat( + # [ + # incremental_v, + # v[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_v), + # ], + # dim=0, + # ) + #if is_cuda_graph: + # incremental_q = torch.cat( + # [ + # incremental_q, + # torch.zeros( + # [max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], + # dtype=dtype, + # device=incremental_q.device, + # ), + # ], + # dim=0, + # ) + # incremental_k = torch.cat( + # [ + # incremental_k, + # torch.zeros( + # [ + # max_tokens - sum(sim.step_lens), + # config.num_gqa_groups, + # config.head_dim_v, + # ], + # dtype=dtype, + # device=incremental_k.device, + # ), + # ], + # dim=0, + # ) + # incremental_v = 
torch.cat( + # [ + # incremental_v, + # torch.zeros( + # [ + # max_tokens - sum(sim.step_lens), + # config.num_gqa_groups, + # config.head_dim_v, + # ], + # dtype=dtype, + # device=incremental_v.device, + # ), + # ], + # dim=0, + # ) else: - incremental_q = torch.zeros( - batch_size, - max_seqlen_q, - config.num_heads, - config.head_dim_qk, - dtype=dtype, - device="cuda", - ) - incremental_k = torch.zeros( + #incremental_q = torch.zeros( + # batch_size, + # max_seqlen_q, + # config.num_heads, + # config.head_dim_qk, + # dtype=dtype, + # device="cuda", + #) + #incremental_k = torch.zeros( + # batch_size, + # max_seqlen_q, + # config.num_gqa_groups, + # config.head_dim_qk, + # dtype=dtype, + # device="cuda", + #) + #incremental_v = torch.zeros( + # batch_size, + # max_seqlen_q, + # config.num_gqa_groups, + # config.head_dim_v, + # dtype=dtype, + # device="cuda", + #) + #for i, seq in enumerate(sim.t_seq_ids): + # start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + # end = sim.t_total_lens[i].item() + # incremental_q[i, : sim.step_lens[i], :, :] = q[seq, start:end, :, :] + # incremental_k[i, : sim.step_lens[i], :, :] = k[seq, start:end, :, :] + # incremental_v[i, : sim.step_lens[i], :, :] = v[seq, start:end, :, :] + #if qkv_format == "sbhd": + # incremental_q, incremental_k, incremental_v = [ + # x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] + # ] + incremental_hidden_states = torch.zeros( batch_size, max_seqlen_q, - config.num_gqa_groups, - config.head_dim_qk, - dtype=dtype, - device="cuda", - ) - incremental_v = torch.zeros( - batch_size, - max_seqlen_q, - config.num_gqa_groups, - config.head_dim_v, + config.num_heads * config.head_dim_qk, dtype=dtype, device="cuda", ) + print('sim.t_seq_ids', sim.t_seq_ids) + print(sim.t_total_lens, sim.step_lens) for i, seq in enumerate(sim.t_seq_ids): start = (sim.t_total_lens[i] - sim.step_lens[i]).item() end = sim.t_total_lens[i].item() - incremental_q[i, : sim.step_lens[i], :, :] = q[seq, 
start:end, :, :] - incremental_k[i, : sim.step_lens[i], :, :] = k[seq, start:end, :, :] - incremental_v[i, : sim.step_lens[i], :, :] = v[seq, start:end, :, :] + incremental_hidden_states[i, : sim.step_lens[i], :] = hidden_states[seq, start:end, :] + #print(hidden_states[0,0,:4]) + #print(incremental_hidden_states[0,0,:4]) + #print(hidden_states[1,6,:4]) + #print(incremental_hidden_states[1,6,:4]) if qkv_format == "sbhd": - incremental_q, incremental_k, incremental_v = [ - x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] - ] + incremental_hidden_states = incremental_hidden_states.transpose(0, 1).contiguous() # run step batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size @@ -523,24 +667,65 @@ def gen_data(): inference_params.pre_step(step_dict) if inference_params.is_paged: inference_params.cache_manager.print_cache() - line_output = model( - incremental_q, - incremental_k, - incremental_v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - inference_params=inference_params, - attn_mask_type="padding", # _causal", - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - qkv_format=qkv_format, - ) + #line_output = model( + # incremental_q, + # incremental_k, + # incremental_v, + # cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_kv=cu_seqlens_kv, + # inference_params=inference_params, + # attn_mask_type="padding", # _causal", + # max_seqlen_q=max_seqlen_q, + # max_seqlen_kv=config.max_seqlen_kv, + # qkv_format=qkv_format, + #) + h = incremental_hidden_states + if not is_cuda_graph: + for m in model: + h = m( + h, + #rotary_pos_emb=rotary_freqs, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + inference_params=inference_params, + #self_attn_mask_type="padding", #_causal", + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + ) + #print('full h', h[0,0,:4]) + #print('full h', h[1,6,:4]) + else: + for _ in range(num_layers): + h = model( + h, + #rotary_pos_emb=rotary_freqs, + 
cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + inference_params=inference_params, + #self_attn_mask_type="padding", #_causal", + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + ) + line_output = h + print('cu_seqlens_q', cu_seqlens_q) + print('cu_seqlens_kv', cu_seqlens_kv) + #line_output = model( + # hidden_states=incremental_hidden_states, + # cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_kv=cu_seqlens_kv, + # inference_params=inference_params, + # self_attn_mask_type="padding", # _causal", + # max_seqlen_q=max_seqlen_q, + # max_seqlen_kv=config.max_seqlen_kv, + # #qkv_format=qkv_format, + # #rotary_pos_emb=rotary_freqs, + # ) # compare results if backend != "FlashAttention": tols = { torch.float32: 1e-3, - torch.half: 1e-3, + torch.half: 3e-3, torch.bfloat16: 1e-2, } else: @@ -552,6 +737,10 @@ def gen_data(): for i, seq in enumerate(sim.t_seq_ids): token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 if qkv_format == "bshd": + print(i, seq, sim.t_total_lens, sim.step_lens, token_index) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[i, token_index, :4]) + print(line_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], # line_output[:sim.step_lens[i] - 1, i, :], @@ -561,6 +750,9 @@ def gen_data(): rtol=tols[dtype], ) if qkv_format == "sbhd": + print(i, seq, sim.t_total_lens, sim.step_lens) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(line_output[token_index, i, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], # line_output[:sim.step_lens[i] - 1, i, :], From 0c72a612093d4422167eb91c27b55cfbd43df954 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 27 Feb 2025 09:30:12 -0800 Subject: [PATCH 125/239] WIP: minor clean up Signed-off-by: Charlene 
Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 494 +++++++------------- 1 file changed, 166 insertions(+), 328 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index f67a25fb9c..f001be5e47 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -17,7 +17,6 @@ TransformerLayer, ) from transformer_engine.pytorch.attention import ( - MultiheadAttention, DotProductAttention, InferenceParams, ) @@ -206,74 +205,32 @@ def step(self, dynamic_fill: bool = True): self.t_batch_size = len(self.t_seq_ids) self.t_total_lens = self.t_ctx_lens + self.t_gen_lens +def generate_args(num_tensors: int, shapes: List, dtype: torch.dtype, warmup: bool = False): + if len(shapes) == 1: + shapes = shapes * num_tensors + func = torch.ones if warmup else torch.randn + scale = 1 if warmup else 0.1 + return [ + scale * func( + *shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] @pytest.mark.parametrize("dtype", [torch.float16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FlashAttention"]) #, "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("module", ["TransformerLayer"])#, "MultiHeadAttention", "DotProductAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"]) #, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("module", ["TransformerLayer"])#, "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): - reset_rng_states() - #_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - #_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - #def 
get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: - # """Get cuda rng tracker.""" - # return _DUMMY_CUDA_RNG_STATE_TRACKER - - num_layers = 1 #2 - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, num_layers) logger = logging.getLogger("test_paged_attn") - + sigma = 0.023 + num_layers = 1 #2 config = model_configs_infer[model] - hidden_size = config.head_dim_qk * config.num_heads - if module == "TransformerLayer": - model = [TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4*hidden_size, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type="causal", - params_dtype=dtype, - #get_rng_state_tracker=get_dummy_cuda_rng_tracker, - attn_input_format="bshd", - ).cuda().eval() for layer_number in range(1, num_layers+1)] - if module == "MultiHeadAttention": - model = [MultiHeadAttention( - hidden_size=hidden_size, - num_attention_heads=config.num_heads, - kv_channels=config.head_dim_qk, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - attn_mask_type="causal", - num_gqa_groups=config.num_gqa_groups, - #get_rng_state_tracker=get_dummy_cuda_rng_tracker, - params_dtype=dtype, - qkv_format="bshd", - ).cuda().eval() for layer_number in range(1, num_layers+1)] - if module == "DotProductAttention": - model = [DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - qkv_format="bshd", - attn_mask_type="causal", - #get_rng_state_tracker=get_dummy_cuda_rng_tracker, - ).cuda().eval() for layer_number in range(1, 
num_layers+1)] - #model = torch.nn.Sequential(*model).cuda().eval() # figure out supported backends inference_params_qkv_format = "bshd" @@ -304,65 +261,71 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config.max_seqlen_q = 256 config.max_seqlen_kv = 256 - # create model - #model = ( - # DotProductAttention( - # kv_channels=config.head_dim_qk, - # num_attention_heads=config.num_heads, - # num_gqa_groups=config.num_gqa_groups, - # layer_number=layer_number, - # attention_dropout=config.dropout_p, - # ) - # .cuda() - # .eval() - #) + # create full model + reset_rng_states() + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, num_layers) + hidden_size = config.head_dim_qk * config.num_heads + if module == "TransformerLayer": + model = [TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4*hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type="causal", + params_dtype=dtype, + attn_input_format="bshd", + ).cuda().eval() for layer_number in range(1, num_layers+1)] + if module == "DotProductAttention": + model = [DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format="bshd", + attn_mask_type="causal", + ).cuda().eval() for layer_number in range(1, num_layers+1)] # generate data for all requests assert ( config.max_seqlen_q == config.max_seqlen_kv ), "This test only simulates max_seqlen_q = max_seqlen_kv." 
- q = 0.1 * torch.randn( - (config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk), - dtype=dtype, - device="cuda", - ) - k = 0.1 * torch.randn( - (config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk), - dtype=dtype, - device="cuda", - ) - v = 0.1 * torch.randn( - (config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v), - dtype=dtype, - device="cuda", - ) + shapes = [] + if module == "TransformerLayer": + shapes.append([config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk]) + num_tensors = 1 + if module == "DotProductAttention": + shapes.append([config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk]) + shapes.append([config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk]) + shapes.append([config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v]) + num_tensors = 3 + full_inputs = generate_args(num_tensors, shapes, dtype, warmup=False) # generate reference results logger.info("=== Generating all tokens at once ===") - #full_output = model( - # query_layer=q, - # key_layer=k, - # value_layer=v, - # qkv_format="bshd", - # attn_mask_type="causal", - #) - hidden_states = torch.randn( - (config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk), - dtype=dtype, - device="cuda", - ) + if module == "DotProductAttention": + full_output = full_inputs + for m in model: + full_output = m( + *full_output if isinstance(full_output, List) else full_output, + ) #rotary_freqs = torch.randn((config.max_seqlen_kv, 1, 1, config.num_heads), dtype=torch.float, device="cuda") - h = hidden_states - for m in model: - h = m( - h, - self_attn_mask_type="causal", - #inference_params=inference_params, - #rotary_pos_emb=rotary_freqs, + if module == "TransformerLayer": + full_output = full_inputs + for m in model: + full_output = m( + *full_output if 
isinstance(full_output, List) else full_output, + #rotary_pos_emb=rotary_freqs, ) - #print('full h', h[0,0,:4]) - #print('full h', h[1,6,:4]) - full_output = h + #print('full h', h[0,0,:4]) + #print('full h', h[1,6,:4]) # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -401,33 +364,42 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda qkv_format=qkv_format, allow_query_conversion=backend != "FusedAttention", ) - #print('inference_params.cache_manager', inference_params.cache_manager) for layer_number in range(1, num_layers+1): inference_params.allocate_memory(layer_number, qkv_format) + # create inference model reset_rng_states() - sigma = 0.023 init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, num_layers) attn_mask_type = "padding_causal" if backend == "FlashAttention" else "padding" - model = [TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4*hidden_size, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type=attn_mask_type, #"padding", #_causal", - #enc_dec_attn_mask_type="padding", #_causal", - params_dtype=dtype, - #get_rng_state_tracker=get_dummy_cuda_rng_tracker, - attn_input_format=qkv_format, - ).cuda().eval() for layer_number in range(1, num_layers+1)] - #model = torch.nn.Sequential(*model).cuda().eval() + if module == "TransformerLayer": + model = [TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4*hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + 
kv_channels=config.head_dim_qk, + self_attn_mask_type=attn_mask_type, #"padding", #_causal", + #enc_dec_attn_mask_type="padding", #_causal", + params_dtype=dtype, + attn_input_format=qkv_format, + ).cuda().eval() for layer_number in range(1, num_layers+1)] + if module == "DotProductAttention": + model = [DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + ).cuda().eval() for layer_number in range(1, num_layers+1)] + # graph the model if necessary if is_cuda_graph: t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") @@ -442,17 +414,14 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if qkv_format == "thd": shape = [config.batch_size * config.max_ctx_len] - def gen_data(): - return [ - torch.ones( - *shape, - config.num_heads * config.head_dim_qk, - device="cuda", - dtype=dtype, - ) - #for _ in range(3) - for _ in range(1) - ] + shapes = [] + if module == "TransformerLayer": + shapes.append([*shape, config.num_heads * config.head_dim_qk]) + if module == "DotProductAttention": + shapes.append([*shape, config.num_heads, config.head_dim_qk]) + shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk]) + shapes.append([*shape, config.num_gqa_groups, config.head_dim_v]) + sample_args = generate_args(num_tensors, shapes, dtype, warmup=True) sample_kwargs = {} sample_kwargs["cu_seqlens_q"] = torch.linspace( @@ -470,16 +439,12 @@ def gen_data(): dtype=torch.int32, ) sample_kwargs["inference_params"] = inference_params - #sample_kwargs["attn_mask_type"] = "padding" # _causal" - #sample_kwargs["self_attn_mask_type"] = "padding" # _causal" sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - #sample_kwargs["qkv_format"] = qkv_format - 
#sample_kwargs["attn_input_format"] = qkv_format model = make_graphed_callables( model[0], - gen_data(), + sample_args, num_warmup_iters=10, fp8_enabled=False, sample_kwargs=sample_kwargs, @@ -488,7 +453,6 @@ def gen_data(): sim.reset() inference_params.reset() step_dict = OrderedDict() - print('++++++++++++++== graphed ++++++++++') # simulate step by step # t-1: ... @@ -522,140 +486,45 @@ def gen_data(): batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() if qkv_format == "thd": - incremental_hidden_states = torch.Tensor().to(dtype=dtype, device="cuda") - for i, seq in enumerate(sim.t_seq_ids): - start = (sim.t_total_lens[i] - sim.step_lens[i]).item() - end = sim.t_total_lens[i].item() - incremental_hidden_states = torch.cat([incremental_hidden_states, hidden_states[seq, start:end, :]], dim=0) - if is_cuda_graph: - incremental_hidden_states = torch.cat( - [ - incremental_hidden_states, - torch.zeros( - [max_tokens - sum(sim.step_lens), config.num_heads * config.head_dim_qk], - dtype=dtype, - device=incremental_hidden_states.device, - ), - ], - dim=0, - ) - #incremental_q = torch.Tensor().to(dtype=dtype, device="cuda") - #incremental_k = torch.Tensor().to(dtype=dtype, device="cuda") - #incremental_v = torch.Tensor().to(dtype=dtype, device="cuda") - #for i, seq in enumerate(sim.t_seq_ids): - # start = (sim.t_total_lens[i] - sim.step_lens[i]).item() - # end = sim.t_total_lens[i].item() - # incremental_q = torch.cat([incremental_q, q[seq, start:end, :, :]], dim=0) - # incremental_k = torch.cat( - # [ - # incremental_k, - # k[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_qk), - # ], - # dim=0, - # ) - # incremental_v = torch.cat( - # [ - # incremental_v, - # v[seq, start:end, :, :].view(-1, config.num_gqa_groups, config.head_dim_v), - # ], - # dim=0, - # ) - #if is_cuda_graph: - # incremental_q = torch.cat( - # [ - # incremental_q, - # torch.zeros( - # 
[max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], - # dtype=dtype, - # device=incremental_q.device, - # ), - # ], - # dim=0, - # ) - # incremental_k = torch.cat( - # [ - # incremental_k, - # torch.zeros( - # [ - # max_tokens - sum(sim.step_lens), - # config.num_gqa_groups, - # config.head_dim_v, - # ], - # dtype=dtype, - # device=incremental_k.device, - # ), - # ], - # dim=0, - # ) - # incremental_v = torch.cat( - # [ - # incremental_v, - # torch.zeros( - # [ - # max_tokens - sum(sim.step_lens), - # config.num_gqa_groups, - # config.head_dim_v, - # ], - # dtype=dtype, - # device=incremental_v.device, - # ), - # ], - # dim=0, - # ) + incremental_inputs = [] + for i in range(num_tensors): + inp = full_inputs[i] + inc_inp = torch.Tensor().to(dtype=dtype, device="cuda") + for i, seq in enumerate(sim.t_seq_ids): + start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + end = sim.t_total_lens[i].item() + inc_inp = torch.cat([inc_inp, inp[seq, start:end]], dim=0) + if is_cuda_graph: + inc_inp = torch.cat( + [ + inc_inp, + torch.zeros( + max_tokens - sum(sim.step_lens), *inp.shape[2:], + dtype=dtype, + device=inc_inp.device, + ), + ], + dim=0, + ) + incremental_inputs.append(inc_inp) else: - #incremental_q = torch.zeros( - # batch_size, - # max_seqlen_q, - # config.num_heads, - # config.head_dim_qk, - # dtype=dtype, - # device="cuda", - #) - #incremental_k = torch.zeros( - # batch_size, - # max_seqlen_q, - # config.num_gqa_groups, - # config.head_dim_qk, - # dtype=dtype, - # device="cuda", - #) - #incremental_v = torch.zeros( - # batch_size, - # max_seqlen_q, - # config.num_gqa_groups, - # config.head_dim_v, - # dtype=dtype, - # device="cuda", - #) - #for i, seq in enumerate(sim.t_seq_ids): - # start = (sim.t_total_lens[i] - sim.step_lens[i]).item() - # end = sim.t_total_lens[i].item() - # incremental_q[i, : sim.step_lens[i], :, :] = q[seq, start:end, :, :] - # incremental_k[i, : sim.step_lens[i], :, :] = k[seq, start:end, :, :] - # 
incremental_v[i, : sim.step_lens[i], :, :] = v[seq, start:end, :, :] - #if qkv_format == "sbhd": - # incremental_q, incremental_k, incremental_v = [ - # x.transpose(0, 1) for x in [incremental_q, incremental_k, incremental_v] - # ] - incremental_hidden_states = torch.zeros( - batch_size, - max_seqlen_q, - config.num_heads * config.head_dim_qk, - dtype=dtype, - device="cuda", - ) - print('sim.t_seq_ids', sim.t_seq_ids) - print(sim.t_total_lens, sim.step_lens) - for i, seq in enumerate(sim.t_seq_ids): - start = (sim.t_total_lens[i] - sim.step_lens[i]).item() - end = sim.t_total_lens[i].item() - incremental_hidden_states[i, : sim.step_lens[i], :] = hidden_states[seq, start:end, :] - #print(hidden_states[0,0,:4]) - #print(incremental_hidden_states[0,0,:4]) - #print(hidden_states[1,6,:4]) - #print(incremental_hidden_states[1,6,:4]) - if qkv_format == "sbhd": - incremental_hidden_states = incremental_hidden_states.transpose(0, 1).contiguous() + incremental_inputs = [] + for i in range(num_tensors): + inp = full_inputs[i] + inc_inp = torch.zeros( + batch_size, + max_seqlen_q, + *inp.shape[2:], + dtype=dtype, + device="cuda", + ) + for i, seq in enumerate(sim.t_seq_ids): + start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + end = sim.t_total_lens[i].item() + inc_inp[i, : sim.step_lens[i], :] = inp[seq, start:end] + if qkv_format == "sbhd": + inc_inp = inc_inp.transpose(0, 1).contiguous() + incremental_inputs.append(inc_inp) # run step batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size @@ -667,59 +536,28 @@ def gen_data(): inference_params.pre_step(step_dict) if inference_params.is_paged: inference_params.cache_manager.print_cache() - #line_output = model( - # incremental_q, - # incremental_k, - # incremental_v, - # cu_seqlens_q=cu_seqlens_q, - # cu_seqlens_kv=cu_seqlens_kv, - # inference_params=inference_params, - # attn_mask_type="padding", # _causal", - # max_seqlen_q=max_seqlen_q, - # max_seqlen_kv=config.max_seqlen_kv, - # qkv_format=qkv_format, - 
#) - h = incremental_hidden_states if not is_cuda_graph: + incremental_output = incremental_inputs for m in model: - h = m( - h, - #rotary_pos_emb=rotary_freqs, + incremental_output = m( + *incremental_output if isinstance(incremental_output, List) else incremental_output, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, - #self_attn_mask_type="padding", #_causal", max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, - ) - #print('full h', h[0,0,:4]) - #print('full h', h[1,6,:4]) + ) else: + incremental_output = incremental_inputs for _ in range(num_layers): - h = model( - h, - #rotary_pos_emb=rotary_freqs, + incremental_output = model( + *incremental_output if isinstance(incremental_output, List) else incremental_output, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, - #self_attn_mask_type="padding", #_causal", max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, - ) - line_output = h - print('cu_seqlens_q', cu_seqlens_q) - print('cu_seqlens_kv', cu_seqlens_kv) - #line_output = model( - # hidden_states=incremental_hidden_states, - # cu_seqlens_q=cu_seqlens_q, - # cu_seqlens_kv=cu_seqlens_kv, - # inference_params=inference_params, - # self_attn_mask_type="padding", # _causal", - # max_seqlen_q=max_seqlen_q, - # max_seqlen_kv=config.max_seqlen_kv, - # #qkv_format=qkv_format, - # #rotary_pos_emb=rotary_freqs, - # ) + ) # compare results if backend != "FlashAttention": @@ -739,37 +577,37 @@ def gen_data(): if qkv_format == "bshd": print(i, seq, sim.t_total_lens, sim.step_lens, token_index) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(line_output[i, token_index, :4]) - print(line_output[i, sim.step_lens[i] - 1, :4]) + print(incremental_output[i, token_index, :4]) + print(incremental_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - # 
line_output[:sim.step_lens[i] - 1, i, :], + # incremental_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], - line_output[i, token_index, :], + incremental_output[i, token_index, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "sbhd": print(i, seq, sim.t_total_lens, sim.step_lens) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(line_output[token_index, i, :4]) + print(incremental_output[token_index, i, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - # line_output[:sim.step_lens[i] - 1, i, :], + # incremental_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], - line_output[token_index, i, :], + incremental_output[token_index, i, :], atol=tols[dtype], rtol=tols[dtype], ) if qkv_format == "thd": # print('i ', i, seq, cu_seqlens_q) # print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - # print(line_output[cu_seqlens_q[i + 1] - 1, :4]) + # print(incremental_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - # line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], + # incremental_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], full_output[seq, sim.t_total_lens[i] - 1, :], - line_output[cu_seqlens_q[i + 1] - 1, :], + incremental_output[cu_seqlens_q[i + 1] - 1, :], atol=tols[dtype], rtol=tols[dtype], ) From e06f4a1914b58d381cf242a3bbdaac6f19040368 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 17:30:56 +0000 Subject: [PATCH 126/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_paged_attn.py | 172 ++++++++++++-------- transformer_engine/pytorch/attention.py | 5 +- transformer_engine/pytorch/graph.py | 6 +- 3 files changed, 115 
insertions(+), 68 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index f001be5e47..cf413d3864 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -89,7 +89,7 @@ def __init__( self.context_lens = torch.randint( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) - #self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -205,13 +205,15 @@ def step(self, dynamic_fill: bool = True): self.t_batch_size = len(self.t_seq_ids) self.t_total_lens = self.t_ctx_lens + self.t_gen_lens + def generate_args(num_tensors: int, shapes: List, dtype: torch.dtype, warmup: bool = False): if len(shapes) == 1: shapes = shapes * num_tensors func = torch.ones if warmup else torch.randn scale = 1 if warmup else 0.1 return [ - scale * func( + scale + * func( *shapes[i], device="cuda", dtype=dtype, @@ -219,17 +221,18 @@ def generate_args(num_tensors: int, shapes: List, dtype: torch.dtype, warmup: bo for i in range(num_tensors) ] + @pytest.mark.parametrize("dtype", [torch.float16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"]) #, "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("module", ["TransformerLayer"])#, "DotProductAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("module", ["TransformerLayer"]) # , "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, 
backend, module, is_cuda_graph): logger = logging.getLogger("test_paged_attn") sigma = 0.023 - num_layers = 1 #2 + num_layers = 1 # 2 config = model_configs_infer[model] # figure out supported backends @@ -267,31 +270,41 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda output_layer_init_method = scaled_init_method_normal(sigma, num_layers) hidden_size = config.head_dim_qk * config.num_heads if module == "TransformerLayer": - model = [TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4*hidden_size, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type="causal", - params_dtype=dtype, - attn_input_format="bshd", - ).cuda().eval() for layer_number in range(1, num_layers+1)] + model = [ + TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4 * hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type="causal", + params_dtype=dtype, + attn_input_format="bshd", + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] if module == "DotProductAttention": - model = [DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - qkv_format="bshd", - attn_mask_type="causal", - ).cuda().eval() for layer_number in range(1, num_layers+1)] + model = [ + DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + 
num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format="bshd", + attn_mask_type="causal", + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] # generate data for all requests assert ( @@ -299,12 +312,20 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ), "This test only simulates max_seqlen_q = max_seqlen_kv." shapes = [] if module == "TransformerLayer": - shapes.append([config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk]) + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk] + ) num_tensors = 1 if module == "DotProductAttention": - shapes.append([config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk]) - shapes.append([config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk]) - shapes.append([config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v]) + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk] + ) + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk] + ) + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v] + ) num_tensors = 3 full_inputs = generate_args(num_tensors, shapes, dtype, warmup=False) @@ -316,16 +337,16 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda full_output = m( *full_output if isinstance(full_output, List) else full_output, ) - #rotary_freqs = torch.randn((config.max_seqlen_kv, 1, 1, config.num_heads), dtype=torch.float, device="cuda") + # rotary_freqs = torch.randn((config.max_seqlen_kv, 1, 1, config.num_heads), dtype=torch.float, device="cuda") if module == "TransformerLayer": full_output = full_inputs for m in model: full_output = m( *full_output if 
isinstance(full_output, List) else full_output, - #rotary_pos_emb=rotary_freqs, + # rotary_pos_emb=rotary_freqs, ) - #print('full h', h[0,0,:4]) - #print('full h', h[1,6,:4]) + # print('full h', h[0,0,:4]) + # print('full h', h[1,6,:4]) # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -364,7 +385,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda qkv_format=qkv_format, allow_query_conversion=backend != "FusedAttention", ) - for layer_number in range(1, num_layers+1): + for layer_number in range(1, num_layers + 1): inference_params.allocate_memory(layer_number, qkv_format) # create inference model @@ -373,32 +394,42 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda output_layer_init_method = scaled_init_method_normal(sigma, num_layers) attn_mask_type = "padding_causal" if backend == "FlashAttention" else "padding" if module == "TransformerLayer": - model = [TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4*hidden_size, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type=attn_mask_type, #"padding", #_causal", - #enc_dec_attn_mask_type="padding", #_causal", - params_dtype=dtype, - attn_input_format=qkv_format, - ).cuda().eval() for layer_number in range(1, num_layers+1)] + model = [ + TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4 * hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type=attn_mask_type, # "padding", 
#_causal", + # enc_dec_attn_mask_type="padding", #_causal", + params_dtype=dtype, + attn_input_format=qkv_format, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] if module == "DotProductAttention": - model = [DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - ).cuda().eval() for layer_number in range(1, num_layers+1)] + model = [ + DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] # graph the model if necessary if is_cuda_graph: @@ -499,7 +530,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda [ inc_inp, torch.zeros( - max_tokens - sum(sim.step_lens), *inp.shape[2:], + max_tokens - sum(sim.step_lens), + *inp.shape[2:], dtype=dtype, device=inc_inp.device, ), @@ -540,7 +572,11 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda incremental_output = incremental_inputs for m in model: incremental_output = m( - *incremental_output if isinstance(incremental_output, List) else incremental_output, + *( + incremental_output + if isinstance(incremental_output, List) + else incremental_output + ), cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, @@ -551,7 +587,11 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda incremental_output = incremental_inputs for _ in range(num_layers): incremental_output = model( - *incremental_output if isinstance(incremental_output, List) else incremental_output, + *( + incremental_output + if 
isinstance(incremental_output, List) + else incremental_output + ), cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6410b60e21..82b53fdaf5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -8414,7 +8414,10 @@ def forward( # Pre-allocate memory for key-values for inference # ================================================= - if inference_params is not None and self.layer_number not in inference_params.cache_manager.cache: + if ( + inference_params is not None + and self.layer_number not in inference_params.cache_manager.cache + ): inference_params.allocate_memory(self.layer_number, self.qkv_format) # ====================== diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 42b83dfb53..4637a8456c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -333,7 +333,11 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if callables[0].training and isinstance(arg, torch.Tensor) and arg.requires_grad: + if ( + callables[0].training + and isinstance(arg, torch.Tensor) + and arg.requires_grad + ): static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: From fb77772832aee468f4c5b9935526ea00c051829e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 27 Feb 2025 10:30:11 -0800 Subject: [PATCH 127/239] WIP: clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 278 ++++++++++---------- 1 file changed, 139 insertions(+), 139 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index cf413d3864..620710f393 100644 
--- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -205,12 +205,109 @@ def step(self, dynamic_fill: bool = True): self.t_batch_size = len(self.t_seq_ids) self.t_total_lens = self.t_ctx_lens + self.t_gen_lens +def get_model( + module: torch.nn.Module, + backend: str = "FusedAttention", + config: ModelConfig, + dtype: torch.dtype, + qkv_format: str = "bshd", + num_layers: int = 1, + mode: str = "reference", + ): + reset_rng_states() + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, num_layers) + + if mode == "reference": + attn_mask_type = "causal" + qkv_format = "bshd" + if mode == "inference": + attn_mask_type = "padding_causal" if backend == "FlashAttention" else "padding" + + if module == "TransformerLayer": + hidden_size = config.head_dim_qk * config.num_heads + model = [ + TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4 * hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type=attn_mask_type, # "padding", #_causal", + # enc_dec_attn_mask_type="padding", #_causal", + params_dtype=dtype, + attn_input_format=qkv_format, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] + if module == "DotProductAttention": + model = [ + DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] + return model + +def generate_args( + module: torch.nn.Module, + config: ModelConfig, + dtype: 
torch.dtype, + qkv_format: str = "bshd", + mode: str = "full_inputs", + ): + if mode == "full_inputs": + warmup = False + shapes = [] + if module == "TransformerLayer": + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk] + ) + if module == "DotProductAttention": + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk] + ) + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk] + ) + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v] + ) + elif mode == "sample_args": + warmup = True + shapes = [] + if qkv_format == "bshd": + shape = [config.batch_size, config.max_ctx_len] + if qkv_format == "sbhd": + shape = [config.max_ctx_len, config.batch_size] + if qkv_format == "thd": + shape = [config.batch_size * config.max_ctx_len] + if module == "TransformerLayer": + shapes.append([*shape, config.num_heads * config.head_dim_qk]) + if module == "DotProductAttention": + shapes.append([*shape, config.num_heads, config.head_dim_qk]) + shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk]) + shapes.append([*shape, config.num_gqa_groups, config.head_dim_v]) -def generate_args(num_tensors: int, shapes: List, dtype: torch.dtype, warmup: bool = False): - if len(shapes) == 1: - shapes = shapes * num_tensors func = torch.ones if warmup else torch.randn scale = 1 if warmup else 0.1 + num_tensors = len(shapes) return [ scale * func( @@ -221,6 +318,34 @@ def generate_args(num_tensors: int, shapes: List, dtype: torch.dtype, warmup: bo for i in range(num_tensors) ] +def get_tols(module, backend, dtype) + if module == "TransformerLayer": + if backend != "FlashAttention": + tols = { + torch.float32: 1e-3, + torch.half: 3e-3, + torch.bfloat16: 1e-2, + } + else: + tols = { + torch.float32: 1e-3, + torch.half: 4e-3, + torch.bfloat16: 1e-2, + } + if module == "DotProductAttention": + 
if backend != "FlashAttention": + tols = { + torch.float32: 1e-3, + torch.half: 1e-3, + torch.bfloat16: 1e-2, + } + else: + tols = { + torch.float32: 1e-3, + torch.half: 4e-3, + torch.bfloat16: 1e-2, + } + return tols[dtype] @pytest.mark.parametrize("dtype", [torch.float16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @@ -231,7 +356,6 @@ def generate_args(num_tensors: int, shapes: List, dtype: torch.dtype, warmup: bo @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): logger = logging.getLogger("test_paged_attn") - sigma = 0.023 num_layers = 1 # 2 config = model_configs_infer[model] @@ -265,69 +389,13 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config.max_seqlen_kv = 256 # create full model - reset_rng_states() - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, num_layers) - hidden_size = config.head_dim_qk * config.num_heads - if module == "TransformerLayer": - model = [ - TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4 * hidden_size, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type="causal", - params_dtype=dtype, - attn_input_format="bshd", - ) - .cuda() - .eval() - for layer_number in range(1, num_layers + 1) - ] - if module == "DotProductAttention": - model = [ - DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - qkv_format="bshd", - attn_mask_type="causal", - ) - .cuda() - .eval() - for layer_number in range(1, num_layers + 1) 
- ] + model = get_model(module, backend, config, dtype, qkv_format, num_layers, mode="reference") # generate data for all requests assert ( config.max_seqlen_q == config.max_seqlen_kv ), "This test only simulates max_seqlen_q = max_seqlen_kv." - shapes = [] - if module == "TransformerLayer": - shapes.append( - [config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk] - ) - num_tensors = 1 - if module == "DotProductAttention": - shapes.append( - [config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk] - ) - shapes.append( - [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk] - ) - shapes.append( - [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v] - ) - num_tensors = 3 - full_inputs = generate_args(num_tensors, shapes, dtype, warmup=False) + full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs") # generate reference results logger.info("=== Generating all tokens at once ===") @@ -345,8 +413,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda *full_output if isinstance(full_output, List) else full_output, # rotary_pos_emb=rotary_freqs, ) - # print('full h', h[0,0,:4]) - # print('full h', h[1,6,:4]) # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -389,47 +455,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda inference_params.allocate_memory(layer_number, qkv_format) # create inference model - reset_rng_states() - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, num_layers) - attn_mask_type = "padding_causal" if backend == "FlashAttention" else "padding" - if module == "TransformerLayer": - model = [ - TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4 * hidden_size, - num_attention_heads=config.num_heads, - 
num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type=attn_mask_type, # "padding", #_causal", - # enc_dec_attn_mask_type="padding", #_causal", - params_dtype=dtype, - attn_input_format=qkv_format, - ) - .cuda() - .eval() - for layer_number in range(1, num_layers + 1) - ] - if module == "DotProductAttention": - model = [ - DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - ) - .cuda() - .eval() - for layer_number in range(1, num_layers + 1) - ] + model = get_model(module, backend, config, dtype, qkv_format, num_layers, mode="inference") # graph the model if necessary if is_cuda_graph: @@ -438,22 +464,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist())) inference_params.pre_step(step_dict) - if qkv_format == "bshd": - shape = [config.batch_size, config.max_ctx_len] - if qkv_format == "sbhd": - shape = [config.max_ctx_len, config.batch_size] - if qkv_format == "thd": - shape = [config.batch_size * config.max_ctx_len] - - shapes = [] - if module == "TransformerLayer": - shapes.append([*shape, config.num_heads * config.head_dim_qk]) - if module == "DotProductAttention": - shapes.append([*shape, config.num_heads, config.head_dim_qk]) - shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk]) - shapes.append([*shape, config.num_gqa_groups, config.head_dim_v]) - sample_args = generate_args(num_tensors, shapes, dtype, warmup=True) - + sample_args = generate_args(module, config, dtype, qkv_format=qkv_format, mode="sample_args") sample_kwargs = {} 
sample_kwargs["cu_seqlens_q"] = torch.linspace( 0, @@ -600,18 +611,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ) # compare results - if backend != "FlashAttention": - tols = { - torch.float32: 1e-3, - torch.half: 3e-3, - torch.bfloat16: 1e-2, - } - else: - tols = { - torch.float32: 1e-3, - torch.half: 4e-3, - torch.bfloat16: 1e-2, - } + tol = get_tols(module, backend, dtype) for i, seq in enumerate(sim.t_seq_ids): token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 if qkv_format == "bshd": @@ -624,8 +624,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # incremental_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[i, token_index, :], - atol=tols[dtype], - rtol=tols[dtype], + atol=tol, + rtol=tol, ) if qkv_format == "sbhd": print(i, seq, sim.t_total_lens, sim.step_lens) @@ -636,8 +636,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # incremental_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[token_index, i, :], - atol=tols[dtype], - rtol=tols[dtype], + atol=tol, + rtol=tol, ) if qkv_format == "thd": # print('i ', i, seq, cu_seqlens_q) @@ -648,8 +648,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # incremental_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[cu_seqlens_q[i + 1] - 1, :], - atol=tols[dtype], - rtol=tols[dtype], + atol=tol, + rtol=tol, ) sim.t += 1 sim.t_gen_lens = sim.t_gen_lens + 1 From eb28c650c688f7734adab16087ea309e193404a9 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 27 Feb 2025 15:23:36 -0800 Subject: [PATCH 128/239] Update build test CUDA version to 12.1 (#1517) Signed-off-by: Tim Moon --- .github/workflows/build.yml | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4be7a30a86..3b6202263b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,7 @@ jobs: name: 'Core' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu22.04 + image: nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 options: --user root steps: - name: 'Dependencies' @@ -35,7 +35,7 @@ jobs: name: 'PyTorch' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04 + image: nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu22.04 options: --user root steps: - name: 'Dependencies' From 9654931c81c7ef342fbccbe2ff1150245d30304d Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Thu, 27 Feb 2025 16:13:15 -0800 Subject: [PATCH 129/239] Support vectorized local reduction for p2p-based ReduceScatter overlap (#1452) * Support vectorized local reduction for p2p-based ReduceScatter overlap Signed-off-by: Sangkug Lym * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../userbuffers/userbuffers.cu | 141 ++++++++++++------ 1 file changed, 98 insertions(+), 43 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 735148a811..e6ec1f59f7 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -9,10 +9,9 @@ #include #if __CUDA_ARCH__ >= 800 -#include -#define half nv_bfloat16 +#define half_dtype nv_bfloat16 #else -#include +#define half_dtype half #endif #include @@ -20,6 +19,7 @@ #include #include "common/util/system.h" +#include 
"common/util/vectorized_pointwise.h" #include "userbuffers.h" #define MAX_THREADS 1024 @@ -116,11 +116,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -200,11 +200,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -311,11 +311,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -378,11 +378,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -780,7 +780,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; int lastSM = 0; - half hscale = (half)*scale; + half_dtype hscale = (half_dtype)*scale; if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; @@ -823,13 +823,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } int4 sum[2] = 
{{0, 0, 0, 0}, {0, 0, 0, 0}}; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 0; i < RANKS; i++) { fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half_dtype)(x[j]); } int hline = 2 * line; (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = @@ -855,7 +855,7 @@ __global__ void __launch_bounds__(MAX_THREADS) int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; int lastSM = 0; - half hscale = (half)*scale; + half_dtype hscale = (half_dtype)*scale; if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; @@ -919,13 +919,14 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 0; i < RANKS; i++) { fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) + s[j] += hscale * (half_dtype)(x[j]); } (reinterpret_cast(outbuf))[index1_out] = sum[0]; (reinterpret_cast(outbuf))[index2_out] = sum[1]; @@ -988,11 +989,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -1078,11 +1079,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = 
reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -1169,11 +1170,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -2597,30 +2598,57 @@ void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); } -template +template __global__ void __launch_bounds__(MAX_THREADS / 4) reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, - const int num_inputs, const int input_size) { - const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + const int num_inputs, const int input_size, + const int num_aligned_elements_per_input, + const int tot_input_size) { fp8type *inputs_fp8 = reinterpret_cast(inputs); - float accum_buf = static_cast(inputs_fp8[tid]) * (*scale); + half_dtype *output_half = reinterpret_cast(output); + + transformer_engine::VectorizedLoader loader(inputs_fp8, tot_input_size); + transformer_engine::VectorizedStorer storer(output_half, input_size); + + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + if (tid >= num_aligned_elements_per_input) { + return; + } + float accum_buf[nvec]; + + loader.load(tid, tot_input_size); #pragma unroll - for (int i = 1; i < num_inputs; i++) { - accum_buf += static_cast(inputs_fp8[tid + input_size * i]) * (*scale); + for (int i = 0; i < nvec; ++i) { + accum_buf[i] = static_cast(loader.separate()[i]) * (*scale); } - half *output_half = reinterpret_cast(output); - output_half[tid] = (half)accum_buf; + for (int input_id = 1; input_id < num_inputs; ++input_id) { + loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size); 
+#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] += static_cast(loader.separate()[i]) * (*scale); + } + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + storer.separate()[i] = static_cast(accum_buf[i]); + } + storer.store(tid, input_size); } template void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream) { + constexpr int nvec = 32; + assert(input_size % nvec == 0); + const int num_aligned_elements_per_input = input_size / nvec; + const int tot_input_size = input_size * num_inputs; size_t num_threads = MAX_THREADS / 4; - size_t num_blocks = (input_size + num_threads - 1) / num_threads; + size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads; dim3 block(num_threads); dim3 grid(num_blocks); - reduce_fp8_in_bf16_out_cuda - <<>>(inputs, output, scale, num_inputs, input_size); + reduce_fp8_in_bf16_out_cuda + <<>>(inputs, output, scale, num_inputs, input_size, + num_aligned_elements_per_input, tot_input_size); } template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale, @@ -2630,23 +2658,50 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream); +template __global__ void __launch_bounds__(MAX_THREADS / 4) - reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) { + reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size, + const int num_aligned_elements_per_input, const int tot_input_size) { + half_dtype *inputs_half = reinterpret_cast(inputs); + half_dtype *output_half = reinterpret_cast(output); + + transformer_engine::VectorizedLoader loader(inputs_half, tot_input_size); + transformer_engine::VectorizedStorer storer(output_half, input_size); + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; - half *inputs_half = reinterpret_cast(inputs); - float accum_buf = 
static_cast(inputs_half[tid]); + if (tid >= num_aligned_elements_per_input) { + return; + } + float accum_buf[nvec]; + + loader.load(tid, tot_input_size); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] = static_cast(loader.separate()[i]); + } + for (int input_id = 1; input_id < num_inputs; ++input_id) { + loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] += static_cast(loader.separate()[i]); + } + } #pragma unroll - for (int i = 1; i < num_inputs; i++) { - accum_buf += static_cast(inputs_half[tid + input_size * i]); + for (int i = 0; i < nvec; ++i) { + storer.separate()[i] = static_cast(accum_buf[i]); } - half *output_half = reinterpret_cast(output); - output_half[tid] = (half)accum_buf; + storer.store(tid, input_size); } void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) { + constexpr int nvec = 32; + assert(input_size % nvec == 0); + const int num_aligned_elements_per_input = input_size / nvec; + const int tot_input_size = input_size * num_inputs; size_t num_threads = MAX_THREADS / 4; - size_t num_blocks = (input_size + num_threads - 1) / num_threads; + size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads; dim3 block(num_threads); dim3 grid(num_blocks); - reduce_bf16_cuda<<>>(inputs, output, num_inputs, input_size); + reduce_bf16_cuda<<>>( + inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); } From 97344d669670234de6517053b3eebefb05e3a40b Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Thu, 27 Feb 2025 17:10:34 -0800 Subject: [PATCH 130/239] TP-RS local reduction: fix lint err (#1520) * TP-RS local reduction: fix lint err Signed-off-by: Sangkug Lym * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/comm_gemm_overlap/userbuffers/userbuffers.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index e6ec1f59f7..58de844858 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2614,7 +2614,7 @@ __global__ void __launch_bounds__(MAX_THREADS / 4) if (tid >= num_aligned_elements_per_input) { return; } - float accum_buf[nvec]; + float accum_buf[nvec]; // NOLINT(*) loader.load(tid, tot_input_size); #pragma unroll @@ -2672,7 +2672,7 @@ __global__ void __launch_bounds__(MAX_THREADS / 4) if (tid >= num_aligned_elements_per_input) { return; } - float accum_buf[nvec]; + float accum_buf[nvec]; // NOLINT(*) loader.load(tid, tot_input_size); #pragma unroll From 9588109d4c412aa58bc08f523421d82f71a0cc15 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 28 Feb 2025 17:22:37 +0530 Subject: [PATCH 131/239] Fix shape of new quantized tensor in `make_like` (#1515) * Fix quantized tensor shape Signed-off-by: Kirthi Shankar Sivamani * add shape to make_like; add test for chunk Signed-off-by: Kirthi Shankar Sivamani * Fix typo from suggestion Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 30 +++++++++++++++++++ .../pytorch/tensor/float8_tensor.py | 9 ++++-- .../pytorch/tensor/quantized_tensor.py | 3 +- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 56b01f1dbc..9d01527ac5 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -161,6 +161,36 @@ def test_basic_ops( with pytest.raises(AssertionError): 
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) + @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]]) + def test_chunk_op( + self, + dims: DimsType, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test for ops for which shape of inputs and outputs differ.""" + + # Initialize random data + dims = _to_list(dims) + x_ref = torch.randn(dims, dtype=dtype, device="cpu") + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=1.0) + + # Get chunks. + chunk1, chunk2 = x_fp8.chunk(2, dim=0) + + # Test chunks. + torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0) + torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0) + + # Check shapes. + assert ( + chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:] + ), "Wrong shape for chunk1" + assert ( + chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:] + ), "Wrong shape for chunk2" + def test_inplace_ops( self, dims: DimsType = 23, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 49bf4facfa..c9e65bd93a 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -402,7 +402,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out] + return [ + Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) + for split_tensor in func_out + ] if func == aten.new_zeros.default: tensor = args[0] data = tensor._data @@ -412,7 +415,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out) + return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) if func == 
torch.ops.aten.as_strided.default: tensor = args[0] data = tensor._data @@ -422,7 +425,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out) + return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) if func == torch.ops.aten.detach.default: return cls.detach(args[0]) if func == torch.ops.aten.clone.default: diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index ef21412ca7..b540cd91a1 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -433,7 +433,8 @@ def make_like( data. """ - shape = shape if shape is not None else tensor.shape + if shape is None: + shape = data.shape if data is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() if data is not None: From 303c6d16203b3cb01675f7adb7c21956f140e0ee Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 28 Feb 2025 17:23:13 +0530 Subject: [PATCH 132/239] Enforce PyTorch version 2.1 and run attention tests with torch.compile (#1516) * Enforce torch 2.0 and run attn tests with torch.compile Signed-off-by: Kirthi Shankar Sivamani * replace torch.compile with jit_fuser Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_unittest/test.sh | 2 +- setup.py | 2 +- tests/pytorch/distributed/test_torch_fsdp2.py | 2 +- transformer_engine/pytorch/__init__.py | 16 ++++++++++++++-- transformer_engine/pytorch/attention.py | 16 ++++++++-------- transformer_engine/pytorch/jit.py | 16 ++++------------ transformer_engine/pytorch/ops/_common.py | 2 +- transformer_engine/pytorch/utils.py | 7 ------- 8 files changed, 30 insertions(+), 33 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh 
b/qa/L0_pytorch_unittest/test.sh index 870e869795..fe36b33384 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 exit $FAIL diff --git a/setup.py b/setup.py index 856c518f79..996027bd9e 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - install_reqs.extend(["torch"]) + install_reqs.extend(["torch>=2.1"]) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") test_reqs.extend(["numpy", "torchvision", "prettytable"]) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index bad09bf32a..f5c186a3bc 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -6,8 +6,8 @@ import pytest import subprocess from pathlib import Path +from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.utils import torch_version import torch diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 92250cd322..966115c29e 100644 --- a/transformer_engine/pytorch/__init__.py +++ 
b/transformer_engine/pytorch/__init__.py @@ -7,16 +7,25 @@ # pylint: disable=wrong-import-position,wrong-import-order import logging +import functools +import sys import importlib import importlib.util -import sys -import torch from importlib.metadata import version +from packaging.version import Version as PkgVersion + +import torch from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +@functools.lru_cache(maxsize=None) +def torch_version() -> tuple[int, ...]: + """Get PyTorch version""" + return PkgVersion(str(torch.__version__)).release + + def _load_library(): """Load shared library with Transformer Engine C extensions""" module_name = "transformer_engine_torch" @@ -60,6 +69,9 @@ def _load_library(): spec.loader.exec_module(solib) +assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." + + _load_library() from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7666d3f32b..cc92c1377d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens( return _cu_seqlens_cache[(batch_size, max_seqlen)] -@torch.compile +@jit_fuser def pack_tensor( indices: torch.Tensor, tensor: torch.Tensor, @@ -1409,7 +1409,7 @@ def pack_tensor( return packed -@torch.compile +@jit_fuser def pack_2_tensors( indices: torch.Tensor, t1: torch.Tensor, @@ -1423,7 +1423,7 @@ def pack_2_tensors( return t1_packed, t2_packed -@torch.compile +@jit_fuser def pack_3_tensors( indices: torch.Tensor, t1: torch.Tensor, @@ -1439,7 +1439,7 @@ def pack_3_tensors( return t1_packed, t2_packed, t3_packed -@torch.compile +@jit_fuser def unpack_tensor( indices: torch.Tensor, dim0: int, @@ -1462,7 +1462,7 @@ def unpack_tensor( return unpacked 
-@torch.compile +@jit_fuser def unpack_2_tensors( indices: torch.Tensor, dim0: int, @@ -1477,7 +1477,7 @@ def unpack_2_tensors( return t1_unpacked, t2_unpacked -@torch.compile +@jit_fuser def unpack_3_tensors( indices: torch.Tensor, dim0: int, @@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank( return cu_seqlens_on_cp_rank -@torch.compile +@jit_fuser def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): """ Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. @@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): return chunk_ids -@torch.compile +@jit_fuser def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): """Reorder sequence chunk for A2A communication.""" if before_attn: diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index cda3939d6f..aae35ded68 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -10,28 +10,20 @@ # pylint: disable=unnecessary-lambda-assignment -jit_fuser = torch.jit.script +jit_fuser = lambda func: func if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): jit_fuser = torch.compile + # See: https://github.com/NVIDIA/TransformerEngine/issues/597 dropout_fuser = torch.jit.script if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): dropout_fuser = torch.compile + # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func -if torch.__version__ >= "2": - import torch._dynamo - - if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable( - f, recursive=recursive - ) - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable +no_torch_dynamo = lambda 
recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive) def set_jit_fusion_options() -> None: diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index b4631eb9a7..20e63e0e63 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -10,13 +10,13 @@ import torch from transformer_engine_torch import FP8TensorMeta +from .. import torch_version from ..fp8 import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor from ..utils import ( canonicalize_device, canonicalize_dtype, devices_match, - torch_version, ) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 4678097dc4..1922a7e867 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -8,7 +8,6 @@ import math import os from typing import Any, Callable, List, Optional, Tuple -from packaging.version import Version as PkgVersion import torch import transformer_engine.pytorch.cpp_extensions as ext @@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None: # Pop NVTX range torch.cuda.nvtx.range_pop() - - -@functools.lru_cache(maxsize=None) -def torch_version() -> tuple[int, ...]: - """Get PyTorch version""" - return PkgVersion(str(torch.__version__)).release From d3efaebb6f116566bfa1b8918fbec6d57a751e0c Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 28 Feb 2025 12:51:43 -0800 Subject: [PATCH 133/239] Delete extra tensor objects after restoring float8 tensors (#1500) * delete extra tensor objects after restoring float8 tensors Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit fix Signed-off-by: Sudhakar Singh * fix the leak in float8tensor and mxfloat8tensor classes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 
uncomment the fix Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/module/layernorm_linear.py | 3 +++ transformer_engine/pytorch/module/layernorm_mlp.py | 4 ++++ transformer_engine/pytorch/module/linear.py | 3 +++ .../pytorch/tensor/_internal/float8_tensor_base.py | 4 ++-- .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 4 ++-- transformer_engine/pytorch/tensor/float8_tensor.py | 9 +++++++++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 9 +++++++++ 7 files changed, 32 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 01bda64101..007821038f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -448,6 +448,9 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 88eebc8e6c..f4ee0a1155 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -567,6 +567,10 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. 
+ ctx.tensor_objects = None + # Since main_grad can be modified inplace, it should not be a part of saved_tensors fc1_weight_main_grad = ( ctx.fc1_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index bae21eebfd..83dc652c62 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -354,6 +354,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking restore_from_saved(ctx.tensor_objects, saved_tensors) ) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 6b816db3b5..8ae45c9375 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -105,8 +105,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor """ tensors = [self._data, self._transpose] - # self._data = None - # self._transpose = None + self._data = None + self._transpose = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index d78bd55d9a..ea7fc3cf2f 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -100,8 +100,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB """ tensors = [self._rowwise_data, 
self._columnwise_data] - # self._rowwise_data = None - # self._columnwise_data = None + self._rowwise_data = None + self._columnwise_data = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c9e65bd93a..333b8d1733 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -348,6 +348,15 @@ def clear(self): self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose_invalid = True + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. + + """ + return [self], None + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6e3835fbef..940f2ae46f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -285,6 +285,15 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. 
+ + """ + return [self], None + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From 4b523d298cb0ecbd2015c557346c33928569a268 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 28 Feb 2025 16:29:14 -0800 Subject: [PATCH 134/239] [PyTorch] Set flags in norm modules for Mcore sequence-parallel support (#1528) Set flag in norm modules for Mcore sequence-parallel support Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/layernorm.py | 3 +++ transformer_engine/pytorch/module/rmsnorm.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 1a635afbb8..61aa69818a 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -104,6 +104,9 @@ def __init__( # Flag for sequence parallelism (custom Megatron-LM integration) self.sequence_parallel: Optional[bool] = sequence_parallel + if sequence_parallel is not None: + self.weight.sequence_parallel = sequence_parallel + self.bias.sequence_parallel = sequence_parallel def reset_layer_norm_parameters(self) -> None: """Init LN params""" diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index d2e0d1b2ba..bc826edc2a 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -108,6 +108,8 @@ def __init__( # Flag for sequence parallelism (custom Megatron-LM integration) self.sequence_parallel: Optional[bool] = sequence_parallel + if sequence_parallel is not None: + self.weight.sequence_parallel = sequence_parallel def reset_rms_norm_parameters(self) -> None: """Deprecated""" From d04d80590f55a8d1f69310b1044525c37297002c Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 28 Feb 2025 17:59:30 -0800 Subject: [PATCH 135/239] WIP: switch to flash_attn_varlen_func Signed-off-by: Charlene Yang --- 
tests/pytorch/fused_attn/test_paged_attn.py | 34 +- transformer_engine/pytorch/attention.py | 508 ++++++++++-------- transformer_engine/pytorch/csrc/extensions.h | 6 +- .../pytorch/csrc/extensions/attention.cu | 101 ++-- .../pytorch/csrc/extensions/pybind.cpp | 4 +- transformer_engine/pytorch/csrc/kv_cache.cuh | 46 +- transformer_engine/pytorch/inference.py | 142 +++-- 7 files changed, 453 insertions(+), 388 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 620710f393..f33c92ae2a 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -90,6 +90,8 @@ def __init__( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") + #self.context_lens[0] = 2 + #self.context_lens[2] = 3 # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -108,6 +110,7 @@ def __init__( self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to( dtype=torch.int32, device="cpu" ) + #self.arrival_times[2] = 0 # self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() @@ -207,9 +210,9 @@ def step(self, dynamic_fill: bool = True): def get_model( module: torch.nn.Module, - backend: str = "FusedAttention", config: ModelConfig, dtype: torch.dtype, + backend: str = "FusedAttention", qkv_format: str = "bshd", num_layers: int = 1, mode: str = "reference", @@ -318,7 +321,7 @@ def generate_args( for i in range(num_tensors) ] -def get_tols(module, backend, dtype) +def get_tols(module, backend, dtype): if module == "TransformerLayer": if backend != "FlashAttention": tols = { @@ -329,7 +332,7 @@ def get_tols(module, backend, dtype) else: tols = { torch.float32: 1e-3, - torch.half: 4e-3, + torch.half: 3e-3, torch.bfloat16: 1e-2, } if module == 
"DotProductAttention": @@ -342,7 +345,7 @@ def get_tols(module, backend, dtype) else: tols = { torch.float32: 1e-3, - torch.half: 4e-3, + torch.half: 1e-3, torch.bfloat16: 1e-2, } return tols[dtype] @@ -351,8 +354,8 @@ def get_tols(module, backend, dtype) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("module", ["TransformerLayer"]) # , "DotProductAttention"]) +@pytest.mark.parametrize("backend", ["FlashAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("module", ["DotProductAttention"])#TransformerLayer"]) # , "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): logger = logging.getLogger("test_paged_attn") @@ -389,7 +392,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config.max_seqlen_kv = 256 # create full model - model = get_model(module, backend, config, dtype, qkv_format, num_layers, mode="reference") + model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference") # generate data for all requests assert ( @@ -413,6 +416,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda *full_output if isinstance(full_output, List) else full_output, # rotary_pos_emb=rotary_freqs, ) + print("full", full_output[0,:2,:8]) + print("full", full_output[1,:7,:8]) + print("full", full_output[2,:3,:8]) # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -449,13 +455,13 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda head_dim_q=config.head_dim_qk, max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, - 
allow_query_conversion=backend != "FusedAttention", + #allow_query_conversion=backend != "FusedAttention", ) for layer_number in range(1, num_layers + 1): inference_params.allocate_memory(layer_number, qkv_format) # create inference model - model = get_model(module, backend, config, dtype, qkv_format, num_layers, mode="inference") + model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="inference") # graph the model if necessary if is_cuda_graph: @@ -527,6 +533,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # create incremental input batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() + num_tensors = len(full_inputs) if qkv_format == "thd": incremental_inputs = [] for i in range(num_tensors): @@ -613,12 +620,13 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # compare results tol = get_tols(module, backend, dtype) for i, seq in enumerate(sim.t_seq_ids): - token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 + #token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 + token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": print(i, seq, sim.t_total_lens, sim.step_lens, token_index) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(incremental_output[i, token_index, :4]) - print(incremental_output[i, sim.step_lens[i] - 1, :4]) + #print(incremental_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], # incremental_output[:sim.step_lens[i] - 1, i, :], @@ -628,7 +636,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda rtol=tol, ) if qkv_format == "sbhd": - print(i, seq, sim.t_total_lens, sim.step_lens) + print(i, seq, sim.t_total_lens, sim.step_lens, token_index) 
print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(incremental_output[token_index, i, :4]) torch.testing.assert_close( @@ -653,6 +661,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ) sim.t += 1 sim.t_gen_lens = sim.t_gen_lens + 1 + #if sim.t == 1: + # break sim.serving_times = sim.arrival_times + sim.request_delays sim.complete_times = sim.serving_times + sim.gen_lens diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8f507aee65..00c84b954a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5479,6 +5479,7 @@ def get_qkv_layout( k: torch.Tensor, v: torch.Tensor, qkv_format: str = "sbhd", + inference_params: InferenceParams = None, ) -> str: """Get qkv layout. @@ -5495,6 +5496,8 @@ def get_qkv_layout( the sequence length dimension, `b` batch size, `h` the number of attention heads, `d` head size, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. + inference_params: InferenceParams, default = `None` + InferenceParams related to KV caching. 
Returns ---------- @@ -5636,6 +5639,9 @@ def run_iteratively(q, k, v): if qkv_layout == "not_supported": raise RuntimeError("The provided qkv memory layout is not supported!") + if inference_params is not None and inference_params.is_paged: + qkv_layout = "paged_kv_" + qkv_layout + return qkv_layout, q, k, v, q_format, kv_format @@ -5762,9 +5768,27 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) + # get q_format and kv_format for training and inference + if inference_params is not None: #"_2" in qkv_layout: + #qkv_format = qkv_layout.replace("paged_kv_", "") + #q_format, kv_format = qkv_format.split("_2") + q_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) + kv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] + ) + qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format + else: + qkv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) + q_format = qkv_format + kv_format = qkv_format + + print('FA 0', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout) + # convert q, k, v to bshd if they are in sbhd + # qkv_format is unchanged if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): if qkv_format == "sbhd": # For now just 128, will make it more general in the future @@ -5778,8 +5802,10 @@ def forward( ) else: query_layer, key_layer, value_layer = [ - x.transpose(0, 1) for x in (query_layer, key_layer, value_layer) + x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer) ] + elif q_format == "sbhd" and kv_format == "bshd": + query_layer = query_layer.transpose(0, 1).contiguous() if context_parallel: query_layer, key_layer, value_layer = [ x.contiguous() for x in 
(query_layer, key_layer, value_layer) @@ -5787,31 +5813,34 @@ def forward( else: if qkv_format == "sbhd": query_layer._data, key_layer._data, value_layer._data = [ - x.transpose(0, 1) + x.transpose(0, 1).contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] query_layer, key_layer, value_layer = [ Float8Tensor.make_like(x, data=x._data, shape=x._data.shape) for x in (query_layer, key_layer, value_layer) ] + elif q_format == "sbhd" and kv_format == "bshd": + query_layer._data = query_layer._data.transpose(0, 1).contiguous() + query_layer = Float8Tensor.make_like(query_layer, data=query_layer._data, shape=query_layer._data.shape) if context_parallel: query_layer._data, key_layer._data, value_layer._data = [ x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - batch_size = query_layer.shape[0] - - if qkv_format in ["sbhd", "bshd"]: - max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] - max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size + print('FA 1', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout) + # get accurate batch_size, max_seqlen and cu_seqlens + batch_size = None + if inference_params is None: + if qkv_format in ["sbhd", "bshd"]: + batch_size = query_layer.shape[0] + max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size - if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported with context parallelism!" - cu_seqlens_q = cu_seqlens_q[: batch_size + 1] - cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] + if "padding" in attn_mask_type: + assert not context_parallel, "Padding mask not supported with context parallelism!" 
- if inference_params is None: # [b * s, h, d] query_layer, key_layer, value_layer = [ x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) @@ -5849,30 +5878,58 @@ def forward( key_layer, value_layer = PackTensors.apply( indices_kv, key_layer, value_layer ) - else: - # Cumulative sequence lengths for unpadded data - if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( + else: + # Cumulative sequence lengths for unpadded data + if cu_seqlens_q is None: + cu_seqlens_q = _get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, + ) + if cu_seqlens_kv is None: + cu_seqlens_kv = _get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + if qkv_format == "thd": + assert ( + cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + if max_seqlen_q is None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_q = seqlens_q.max().item() + if max_seqlen_kv is None: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_kv = seqlens_kv.max().item() + else: + if qkv_format in ["sbhd_2bshd", "bshd"]: + # q is in bshd in both cases (conversion above or original input) + batch_size, context_len, num_heads, head_dim = query_layer.shape + cu_seqlens_q = cu_seqlens_q[: batch_size + 1] + cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] + # convert to thd_2bshd + if isinstance(query_layer, Float8Tensor): + query_layer._data = tex.convert_bshd_to_thd( + query_layer._data, + cu_seqlens_q, batch_size, - max_seqlen_q, - query_layer.device, - ) - if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( + context_len, + num_heads, + head_dim, + batch_size * context_len, + ) + query_layer = Float8Tensor.make_like(query_layer, data=query_layer._data, shape=query_layer._data.shape) + else: + query_layer = tex.convert_bshd_to_thd( + query_layer, + cu_seqlens_q, batch_size, - max_seqlen_kv, - key_layer.device, - ) - elif qkv_format == "thd": - 
assert ( - cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" - if max_seqlen_q is None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = seqlens_q.max().item() - if max_seqlen_kv is None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = seqlens_kv.max().item() + context_len, + num_heads, + head_dim, + batch_size * context_len, + ) if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] @@ -5923,126 +5980,139 @@ def forward( if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] - if inference_params is not None: - if _flash_attn_2_2_plus: - func = flash_attn_with_kvcache - if _use_flash_attn_3: - func = flash_attn_with_kvcache_v3 - fa_optional_forward_kwargs_kvcache = {} - cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - fa_optional_forward_kwargs_kvcache["cache_seqlens"] = cache_seqlens - fa_optional_forward_kwargs_kvcache["softmax_scale"] = self.softmax_scale - fa_optional_forward_kwargs_kvcache["causal"] = "causal" in attn_mask_type - if inference_params.is_paged: - fa_optional_forward_kwargs_kvcache["block_table"] = ( + if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type and inference_params is None: + func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 + else: + func = ( + flash_attn_varlen_func + if not _use_flash_attn_3 + else flash_attn_varlen_func_v3 + ) + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append(max_seqlen_q) + fa_optional_forward_args_thd.append(max_seqlen_kv) + if inference_params is not None: + fa_optional_forward_kwargs["block_table"] = ( inference_params.cache_manager.page_table[:batch_size] + if inference_params.is_paged + else 
inference_params.cache_manager.batch_indices.unsqueeze(1)[:batch_size] ) - output = func( - query_layer, - key_layer, - value_layer, - **fa_optional_forward_kwargs_kvcache, - ) - else: - if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: - func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 - else: - func = ( - flash_attn_varlen_func - if not _use_flash_attn_3 - else flash_attn_varlen_func_v3 + if _use_flash_attn_3: + fa_3_optional_forward_kwargs = {} + fa_3_optional_forward_kwargs["window_size"] = window_size + fa_3_optional_forward_kwargs["deterministic"] = self.deterministic + if inference_params is not None: + fa_3_optional_forward_kwargs["page_table"] = ( + inference_params.cache_manager.page_table[:batch_size] + if inference_params.is_paged + else inference_params.cache_manager.batch_indices ) - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) - fa_optional_forward_args_thd.append(max_seqlen_q) - fa_optional_forward_args_thd.append(max_seqlen_kv) - if _use_flash_attn_3: - fa_3_optional_forward_kwargs = {} - fa_3_optional_forward_kwargs["window_size"] = window_size - fa_3_optional_forward_kwargs["deterministic"] = self.deterministic - if fp8: - QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) - torch_orig_dtype = query_layer.dtype - - def convert_to_torch_float8(tensor, dtype): - out = torch.Tensor().to(device=tensor.device, dtype=dtype) - out.set_( - tensor._data.untyped_storage(), - tensor._data.storage_offset(), - tensor._data.shape, - tensor._data.stride(), - ) - return out - - # "fp8_mha" decides outputs in fp8, while inputs are inferred from - # the real dtype - assert isinstance(key_layer, query_layer.__class__) and isinstance( - value_layer, query_layer.__class__ - ), "q, k, and v must have the same type." 
- if not isinstance(query_layer, Float8Tensor): - query_layer, key_layer, value_layer = ( - QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] - ) - fa_3_optional_forward_kwargs["descale_q"] = ( - query_layer._scale_inv.unsqueeze(0) - ) - fa_3_optional_forward_kwargs["descale_k"] = ( - key_layer._scale_inv.unsqueeze(0) - ) - fa_3_optional_forward_kwargs["descale_v"] = ( - value_layer._scale_inv.unsqueeze(0) + if fp8: + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] + torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + torch_orig_dtype = query_layer.dtype + + def convert_to_torch_float8(tensor, dtype): + out = torch.Tensor().to(device=tensor.device, dtype=dtype) + out.set_( + tensor._data.untyped_storage(), + tensor._data.storage_offset(), + tensor._data.shape, + tensor._data.stride(), ) + return out + + # "fp8_mha" decides outputs in fp8, while inputs are inferred from + # the real dtype + assert isinstance(key_layer, query_layer.__class__) and isinstance( + value_layer, query_layer.__class__ + ), "q, k, and v must have the same type." + if not isinstance(query_layer, Float8Tensor): query_layer, key_layer, value_layer = ( - convert_to_torch_float8(x, torch_dtype) - for x in [query_layer, key_layer, value_layer] - ) - try: - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_3_optional_forward_kwargs, + QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] ) - except TypeError as e: - if _flash_attn_3_0_0_beta: - e.args = ( - e.args[0] - + ". Please update your flash-attn v3 (beta) installation" - " as it " - + "may have added more supported arguments to its API. 
\n" - + _flash_attn_3_installation_steps, - ) + e.args[1:] - raise - - if fp8: - output = output.to(dtype=torch_orig_dtype) - if fp8 and fp8_meta["recipe"].fp8_mha: - O_quantizer = quantizers["scaling_fwd"][META_O] - output = O_quantizer(output) - else: - output = func( + fa_3_optional_forward_kwargs["descale_q"] = ( + query_layer._scale_inv.unsqueeze(0) + ) + fa_3_optional_forward_kwargs["descale_k"] = ( + key_layer._scale_inv.unsqueeze(0) + ) + fa_3_optional_forward_kwargs["descale_v"] = ( + value_layer._scale_inv.unsqueeze(0) + ) + query_layer, key_layer, value_layer = ( + convert_to_torch_float8(x, torch_dtype) + for x in [query_layer, key_layer, value_layer] + ) + try: + output, _ = func( query_layer, key_layer, value_layer, *fa_optional_forward_args_thd, - self.attention_dropout if self.training else 0.0, softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, + **fa_3_optional_forward_kwargs, ) + except TypeError as e: + if _flash_attn_3_0_0_beta: + e.args = ( + e.args[0] + + ". Please update your flash-attn v3 (beta) installation" + " as it " + + "may have added more supported arguments to its API. 
\n" + + _flash_attn_3_installation_steps, + ) + e.args[1:] + raise - if ( - qkv_format in ["sbhd", "bshd"] - and "padding" in attn_mask_type - and inference_params is None - ): - output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + if fp8: + output = output.to(dtype=torch_orig_dtype) + if fp8 and fp8_meta["recipe"].fp8_mha: + O_quantizer = quantizers["scaling_fwd"][META_O] + output = O_quantizer(output) + else: + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, + ) - if qkv_format == "sbhd": + if inference_params is None: + if ( + qkv_format in ["sbhd", "bshd"] + and "padding" in attn_mask_type + ): + output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + elif qkv_format in ["bshd", "sbhd_2bshd"]: + # convert back to bshd_2bshd from thd_2bshd + #batch_size, context_len, num_heads, head_dim = output.shape + if isinstance(query_layer, Float8Tensor): + output._data = tex.convert_thd_to_bshd( + output._data, + cu_seqlens_q, + batch_size, + context_len, + num_heads, + head_dim, + ) + output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape) + else: + output = tex.convert_thd_to_bshd( + output, + cu_seqlens_q, + batch_size, + context_len, + num_heads, + head_dim, + ) + + if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) if fp8 and fp8_meta["recipe"].fp8_mha: output_data = ( @@ -6057,10 +6127,10 @@ def convert_to_torch_float8(tensor, dtype): ) else: output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) - elif qkv_format == "bshd": + elif q_format == "bshd": # (bs)hd -> bs(hd) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) - elif qkv_format == "thd": + elif q_format == "thd": # thd -> t(hd) output = output.reshape(output.shape[0], -1) @@ -7375,6 +7445,16 @@ def 
forward( num_gemms=3, allow_non_contiguous=True, ) as query_layer: + # checks for RNG + if self.rng_states_tracker is not None and is_graph_capturing(): + assert isinstance( + self.rng_states_tracker, CudaRNGStatesTracker + ), "Unsupported RNG states tracker." + assert ( + graph_safe_rng_available() + ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + + # checks for FP8 if self.fp8: if self.fp8_meta["recipe"].fp8_mha: if not self.fp8_meta["recipe"].fp8_dpa: @@ -7383,7 +7463,6 @@ def forward( """Forcing fp8_meta["recipe"].fp8_dpa=True due to """ """fp8_meta["recipe"].fp8_mha=True""" ) - if self.fp8 and self.fp8_meta["recipe"].fp8_dpa: forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False) @@ -7395,6 +7474,7 @@ def forward( tex.DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + # checks for q/k/v shapes assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "DotProductAttention only supports CUDA tensors." @@ -7404,31 +7484,28 @@ def forward( assert ( key_layer.shape[:-1] == value_layer.shape[:-1] ), "Keys and values must have the same batch size, sequence length and number of heads!" + num_attention_heads = query_layer.shape[-2] + num_gqa_groups = key_layer.shape[-2] + assert ( + query_layer.shape[-1] == key_layer.shape[-1] + ), "Queries and keys must have the same head dimension!" + head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1] assert ( - key_layer.shape[-1] == self.hidden_size_per_attention_head_k - ), f"Keys have head_dim = {key_layer.shape[-1]}, " + head_dim_qk == self.hidden_size_per_attention_head_k + ), f"Keys have head_dim = {head_dim_qk}, " "but expected head_dim = {self.hidden_size_per_attention_head_k}!" 
assert ( - value_layer.shape[-1] == self.hidden_size_per_attention_head_v - ), f"Values have head_dim = {value_layer.shape[-1]}, " + head_dim_v == self.hidden_size_per_attention_head_v + ), f"Values have head_dim = {head_dim_v}, " "but expected head_dim = {self.hidden_size_per_attention_head_v}!" assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - and value_layer.shape[-2] == self.num_gqa_groups_per_partition + num_gqa_groups == self.num_gqa_groups_per_partition ), ( "Keys and values must have num_gqa_group =" - f" {self.num_gqa_groups_per_partition} heads! Found {key_layer.shape[-2]} in" - f" key_layer and {value_layer.shape[-2]} in value_layer." + f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}." ) - if qkv_format is None: - qkv_format = self.qkv_format - assert qkv_format in [ - "sbhd", - "bshd", - "thd", - ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" - + # checks for attention mask if attn_mask_type is None: attn_mask_type = self.attn_mask_type else: @@ -7439,18 +7516,31 @@ def forward( attn_mask_type in AttnMaskTypes ), f"Attention mask type {attn_mask_type} is not supported!" + # checks for sliding window if window_size is None: window_size = self.window_size window_size = check_set_window_size(attn_mask_type, window_size) - if self.rng_states_tracker is not None and is_graph_capturing(): - assert isinstance( - self.rng_states_tracker, CudaRNGStatesTracker - ), "Unsupported RNG states tracker." - assert ( - graph_safe_rng_available() - ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." - + # checks for qkv_format + if qkv_format is None: + qkv_format = self.qkv_format + assert qkv_format in [ + "sbhd", + "bshd", + "thd", + ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" 
+ if qkv_format in ["sbhd", "bshd"]: + assert all( + len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) + ), f"Queries, keys and values must be 4D tensors when {qkv_format=}!" + if qkv_format == "sbhd": + batch_size = query_layer.shape[1] + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + else: + batch_size = query_layer.shape[0] + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv if qkv_format == "thd": assert all( len(x.shape) == 3 for x in (query_layer, key_layer, value_layer) @@ -7483,54 +7573,39 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - if qkv_format in ["sbhd", "bshd"]: - assert all( - len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) - ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" - if qkv_format == "sbhd": - max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[1] - else: - max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[0] - + # retrieve tokens from KV cache in inference page_table = None if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" - # convert causal to causal_bottom_right in inference when KV-caching is in use - # so users can run with the same attn_mask_type for training and inference - if attn_mask_type in ["causal", "padding_causal"]: + assert "padding" in attn_mask_type, "KV caching requires padding mask!" 
+ if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" - # convert to cross attention type when KV cache is in use self.attention_type = "cross" self.flash_attention.attention_type = self.attention_type self.fused_attention.attention_type = self.attention_type self.unfused_attention.attention_type = self.attention_type - # force tensors to be contiguous if not already query_layer, key_layer, value_layer = [ x.contiguous() if not x.is_contiguous() else x for x in [query_layer, key_layer, value_layer] ] - # update KV cache and return the full key/value tensors + # update KV cache and retrieve full KV tokens ( - query_layer, + #query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, - max_seqlen_q, + #max_seqlen_q, max_seqlen_kv, qkv_format, ) = inference_params.step( self.layer_number, - query_layer, + #query_layer, key_layer, value_layer, qkv_format, @@ -7538,11 +7613,9 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None - if ( - isinstance(query_layer, Float8Tensor) - and isinstance(key_layer, Float8Tensor) - and isinstance(value_layer, Float8Tensor) - ): + print('FA DPA', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format)#, qkv_layout) + # get accurate qkv_layout + if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( qkv_layout, query_layer._data, @@ -7551,16 +7624,21 @@ def forward( q_format, kv_format, ) = get_qkv_layout( - query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format + query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format, inference_params=inference_params, ) else: - qkv_layout, query_layer, key_layer, value_layer, q_format, kv_format = ( - get_qkv_layout(query_layer, key_layer, value_layer, qkv_format=qkv_format) + ( + qkv_layout, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + ) = get_qkv_layout( + query_layer, key_layer, value_layer, qkv_format=qkv_format, 
inference_params=inference_params, ) - # convert qkv layout to its corresponding paged attention layout - if inference_params is not None and inference_params.is_paged: - qkv_layout = "paged_kv_" + qkv_layout + # adjust max_seqlen and cu_seqlens cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -7568,7 +7646,6 @@ def forward( for group in self.cp_group: cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - if q_format in ["sbhd", "bshd"]: max_seqlen_q *= cp_size if cu_seqlens_q is None: @@ -7604,6 +7681,7 @@ def forward( key_layer.device, ) + # set ALiBi attributes global _alibi_cache if alibi_slopes is not None: assert ( @@ -7627,6 +7705,7 @@ def forward( _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True + # detect bias shape core_attention_bias_shape = None if core_attention_bias is not None: if ( @@ -7650,6 +7729,7 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + # check if padding between sequences for qkv_format = thd pad_between_seqs = ( cu_seqlens_q_padded is not None and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) @@ -7658,17 +7738,18 @@ def forward( and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) ) + # gather attention params for get available attention backends attention_params = AttentionParams( qkv_type=type(query_layer), qkv_dtype=query_layer.dtype, qkv_layout=qkv_layout, batch_size=batch_size, - num_heads=query_layer.shape[-2], - num_gqa_groups=key_layer.shape[-2], + num_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - head_dim_qk=query_layer.shape[-1], - head_dim_v=value_layer.shape[-1], + head_dim_qk=head_dim_qk, + head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None 
else None, @@ -7720,9 +7801,11 @@ def forward( fused_attention_backend = _attention_backends["fused_attention_backend"] use_unfused_attention = _attention_backends["use_unfused_attention"] + # if no backend is available, raise exception if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0: raise ValueError("No dot product attention support for the provided inputs!") + # run attention output = None if use_flash_attention: if core_attention_bias_type == "alibi": @@ -7870,9 +7953,12 @@ def forward( inference_params=inference_params, ) - if inference_params is not None: - inference_params.is_output_right_aligned = use_flash_attention - output = inference_params.post_step(self.layer_number, output) + #if inference_params is not None: + ## inference_params.is_output_right_aligned = use_flash_attention + # output = inference_params.post_step(self.layer_number, output) + #print(output[0,-2:,:8]) + #print(output[1,:,:8]) + #print(output[2,-3:,:8]) return output diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3f25b95356..dd069542d9 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -70,10 +70,8 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q, - int d_q, int b, int max_seq_len); -void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned); +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d); +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d, int t); void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, 
torch::Tensor k_cache, torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index ab34e35d22..b43be4c267 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1030,85 +1030,78 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t } /*************************************************************************************************** - * KV Cache: Reshape Q from qkv_format = thd to qkv_format = bshd + * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd **************************************************************************************************/ template -void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, - int h_q, int d_q, int b, int max_seq_len) { +void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d) { transformer_engine::fused_attn:: - reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(new_q.data_ptr()), - reinterpret_cast(q_buffer.data_ptr()), cu_new_lens.data_ptr(), - h_q, d_q, b, max_seq_len); + convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(tensor.data_ptr()), + reinterpret_cast(new_tensor.data_ptr()), cu_seqlens.data_ptr(), + b, max_seq_len, h, d); } -void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens, int h_q, - int d_q, int b, int max_seq_len) { - NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(), - "new_q and q_buffer must be of the same data type."); - if (q_buffer.scalar_type() == at::ScalarType::Half) { +at::Tensor convert_thd_to_bshd(at::Tensor 
tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d) { + std::vector shape = {b, max_seq_len, h, d}; + at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); + if (new_tensor.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); - } else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) { + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); - } else if (q_buffer.scalar_type() == at::ScalarType::Float) { + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); - } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) { + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) { using dtype = at::Float8_e4m3fn; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); - } else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) { + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) { using dtype = at::Float8_e5m2; - reshape_q_launcher(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_seq_len); + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); } + return new_tensor; } /*************************************************************************************************** - * KV Cache: Reshape O from qkv_format = bshd to 
qkv_format = thd + * KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd **************************************************************************************************/ template -void reshape_o_launcher(torch::Tensor output, torch::Tensor output_buffer, - torch::Tensor cu_new_lens, int h_o, int d_o, int b, int max_seq_len, - bool is_output_right_aligned) { +void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, + at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d) { transformer_engine::fused_attn:: - reshape_o_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(output.data_ptr()), - reinterpret_cast(output_buffer.data_ptr()), - cu_new_lens.data_ptr(), h_o, d_o, b, max_seq_len, is_output_right_aligned); + convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(tensor.data_ptr()), + reinterpret_cast(new_tensor.data_ptr()), + cu_seqlens.data_ptr(), b, max_seq_len, h, d); } -void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, - int h_o, int d_o, int b, int max_seq_len, bool is_output_right_aligned) { - NVTE_CHECK(output.scalar_type() == output_buffer.scalar_type(), - "output and output_buffer must be of the same data type."); - if (output.scalar_type() == at::ScalarType::Half) { +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d, int t) { + std::vector shape = {t, h, d}; + at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); + if (tensor.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - } else if (output.scalar_type() == at::ScalarType::BFloat16) { + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::BFloat16) { using dtype 
= at::BFloat16; - reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - } else if (output.scalar_type() == at::ScalarType::Float) { + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::Float) { using dtype = float; - reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - } else if (output.scalar_type() == at::ScalarType::Float8_e4m3fn) { + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) { using dtype = at::Float8_e4m3fn; - reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); - } else if (output.scalar_type() == at::ScalarType::Float8_e5m2) { + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) { using dtype = at::Float8_e5m2; - reshape_o_launcher(output, output_buffer, cu_new_lens, h_o, d_o, b, max_seq_len, - is_output_right_aligned); + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); } else { NVTE_ERROR("Unsupported dtype for KV cache.\n"); } + return new_tensor; } /*************************************************************************************************** @@ -1125,9 +1118,9 @@ void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor **************************************************************************************************/ template -void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, - torch::Tensor v_cache, torch::Tensor page_table, - torch::Tensor cu_new_lens, torch::Tensor cu_cached_lens, +void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, + at::Tensor 
v_cache, at::Tensor page_table, + at::Tensor cu_new_lens, at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged) { @@ -1152,9 +1145,9 @@ void copy_to_kv_cache_launcher(torch::Tensor new_k, torch::Tensor new_v, torch:: } } -void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, - torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, - torch::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, +void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, + at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens, + at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged) { NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 93a86deabd..94e1948bb6 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -191,8 +191,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_attn_bwd", &fused_attn_bwd, "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache"); - m.def("reshape_q", &reshape_q, "Reshape Q for THD before attention"); - m.def("reshape_o", &reshape_o, "Reshape O for THD after attention"); + m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD"); + m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD"); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh
b/transformer_engine/pytorch/csrc/kv_cache.cuh index 6209db18e4..0cdf043d1e 100644 --- a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -9,40 +9,31 @@ namespace transformer_engine { namespace fused_attn { template -__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens, int h_q, - int d_q, int b, int max_seq_len) { - // new_q: thd; q_buffer: bshd; - // cu_new_lens: [b + 1] +__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, int b, int max_seq_len, int h, int d) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int num_elts = (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]) * h_q * d_q; - int new_token_offset = cu_new_lens[batch_idx] * h_q * d_q; - int cache_offset = batch_idx * max_seq_len * h_q * d_q; - scalar_t *new_q_token = new_q + new_token_offset; - scalar_t *q_buffer_token = q_buffer + cache_offset; + int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d; + int thd_offset = cu_seqlens[batch_idx] * h * d; + int bshd_offset = batch_idx * max_seq_len * h * d; + scalar_t *thd_token = tensor + thd_offset; + scalar_t *bshd_token = new_tensor + bshd_offset; for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(q_buffer_token + i) = *(new_q_token + i); + *(bshd_token + i) = *(thd_token + i); } } } template -__global__ void reshape_o_kernel(scalar_t *output, scalar_t *output_buffer, int *cu_new_lens, - int h_o, int d_o, int b, int max_seq_len, - bool is_output_right_aligned) { - // output: bshd; output_buffer: thd; - // cu_new_lens: [b + 1] +__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, + int b, int max_seq_len, int h, int d) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; - int num_elts = new_len * h_o * d_o; - int 
output_offset = batch_idx * max_seq_len * h_o * d_o; - if (is_output_right_aligned) { - output_offset = ((batch_idx + 1) * max_seq_len - new_len) * h_o * d_o; - } - int output_buffer_offset = cu_new_lens[batch_idx] * h_o * d_o; - scalar_t *output_token = output + output_offset; - scalar_t *output_buffer_token = output_buffer + output_buffer_offset; + int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]; + int num_elts = seqlen * h * d; + int bshd_offset = batch_idx * max_seq_len * h * d; + int thd_offset = cu_seqlens[batch_idx] * h * d; + scalar_t *bshd_token = tensor + bshd_offset; + scalar_t *thd_token = new_tensor + thd_offset; for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { - *(output_buffer_token + i) = *(output_token + i); + *(thd_token + i) = *(bshd_token + i); } } } @@ -77,6 +68,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in } } } + if (blockIdx.x == 0) { + for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { + batch_indices[batch_idx] = batch_idx; + } + } } template diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 67dec8c376..ac0d83b18d 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -75,10 +75,7 @@ class InferenceParams: attn_mask_type="padding_causal", ) # assume qkv_format = "bshd" - if inference_params.is_output_right_aligned: - output = output[:,-1] - else: - output = output[:,step_dict.values()] + output = output[:,step_dict.values()] The memory allocation and copies of the new KV tokens to KV cache take place @@ -137,18 +134,6 @@ class DotProductAttention: Format of the incoming query/key/value tensors in current iteration cache_manager: KVCacheManager, default = None Custom cache manager, with KVCacheManager as the base class. - allow_query_conversion: bool, default = True - InferenceParams only supports cache_qkv_format = 'bshd'. 
When qkv_format = {'sbhd', 'thd'}, - output_qkv_format = {'sbhd_2bshd', 'thd_2bshd'}, which are supported by FusedAttention but - not by FlashAttention or UnfusedDotProductAttention. - - For performance, try allow_query_conversion = False first. If it errors out with "No dot - product attention support for the provided inputs!", set allow_query_conversion = True. - - For functionality, set allow_query_conversion = True. InferenceParams converts query from - {'sbhd', 'thd'} to 'bshd', and converts the output back to {'sbhd', 'thd'}. The cost is - two transposes for qkv_format = 'sbhd', and one memory buffer (q_buffer) and two conversion - kernels (reshape_q and reshape_o) for qkv_format = 'thd'. """ def __init__( @@ -167,7 +152,7 @@ def __init__( max_ctx_len: int = None, qkv_format: str = "bshd", cache_manager: KVCacheManager = None, - allow_query_conversion: bool = True, + #allow_query_conversion: bool = True, ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv @@ -179,9 +164,9 @@ def __init__( _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - self.allow_query_conversion = allow_query_conversion and ( - _NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN - ) + #self.allow_query_conversion = allow_query_conversion and ( + # _NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN + #) if not self.is_paged: cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager @@ -221,20 +206,18 @@ def __init__( if qkv_format == "thd": assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" self.max_ctx_len = max_ctx_len - if self.allow_query_conversion: - # query is converted to 'bshd' for certain backends - assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" - assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" 
- self.num_heads_q = num_heads_q - self.head_dim_q = head_dim_q - self.max_seqlen_q = max_ctx_len - self.q_orig = {} - self.q_buffer = {} + #if self.allow_query_conversion: + # # query is converted to 'bshd' for certain backends + # assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" + # assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" + # self.num_heads_q = num_heads_q + # self.head_dim_q = head_dim_q + # self.max_seqlen_q = max_ctx_len # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format or self.allow_query_conversion: + if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: self.output_qkv_format = self.cache_qkv_format else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format @@ -247,15 +230,13 @@ def __init__( self.cu_seqlens_q = None self.cu_seqlens_kv = None - self.is_output_right_aligned = False - def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() self.cache_manager.reset() - if self.input_qkv_format == "thd" and self.allow_query_conversion: - for _, q_buffer in self.q_buffer.items(): - q_buffer.fill_(0) + #if self.input_qkv_format == "thd" and self.allow_query_conversion: + # for _, q_buffer in self.q_buffer.items(): + # q_buffer.fill_(0) def __repr__(self) -> str: if self.is_paged: @@ -301,15 +282,15 @@ def allocate_memory(self, layer_number: int, qkv_format: str): device=torch.cuda.current_device(), ) - if qkv_format == "thd" and self.allow_query_conversion: - self.q_buffer[layer_number] = torch.zeros( - self.max_batch_size, - self.max_ctx_len, - self.num_heads_q, - self.head_dim_q, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) +# if qkv_format == "thd" and self.allow_query_conversion: +# self.q_buffer[layer_number] = torch.zeros( +# self.max_batch_size, +# 
self.max_ctx_len, +# self.num_heads_q, +# self.head_dim_q, +# dtype=self.dtype, +# device=torch.cuda.current_device(), +# ) def pre_step( self, @@ -318,6 +299,7 @@ def pre_step( """Update tracked sequences and prepare for step()""" self.step_dict = step_dict self.batch_size = len(step_dict) + self.total_tokens = sum(step_dict.values()) self.sequences = self.cache_manager.pre_step(step_dict) self.sequences_pre = OrderedDict() @@ -387,7 +369,7 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): def step( self, layer_number: int, - new_q: torch.Tensor, + #new_q: torch.Tensor, new_k: torch.Tensor, new_v: torch.Tensor, qkv_format: str, @@ -399,8 +381,6 @@ def step( ---------- layer_number: int Layer number of attention in the model - new_q: torch.Tensor - New query tokens for layer_number in current inference iteration new_k: torch.Tensor New key tokens for layer_number in current inference iteration new_v: torch.Tensor @@ -410,8 +390,6 @@ def step( Returns ------- - q_buffer: torch.Tensor - new_q reshaped in order to allow certain backends to execute k_cache: torch.Tensor Full key tensor containing both previous and current key tokens v_cache: torch.Tensor @@ -429,34 +407,35 @@ def step( qkv_format: str Updated qkv_format, e.g. 
the input 'thd' format may become 'thd_2bshd' after step() """ + print('self.sequences', self.sequences) self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format or self.allow_query_conversion: + if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: self.output_qkv_format = self.cache_qkv_format else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - q_buffer = new_q - if qkv_format == "bshd": - self.max_seqlen_q = new_q.shape[1] - q_buffer = new_q.contiguous() - if qkv_format == "sbhd": - self.max_seqlen_q = new_q.shape[0] - if self.allow_query_conversion: - q_buffer = new_q.transpose(0, 1).contiguous() - if qkv_format == "thd": - self.max_seqlen_q = self.max_ctx_len - if self.allow_query_conversion: - q_buffer = self.q_buffer[layer_number] - tex.reshape_q( - new_q, - self.q_buffer[layer_number], - self.cu_seqlens_q, - self.num_heads_q, - self.head_dim_q, - self.max_batch_size, - self.max_ctx_len, - ) - self.q_orig[layer_number] = new_q + #q_buffer = new_q + #if qkv_format == "bshd": + # self.max_seqlen_q = new_q.shape[1] + # q_buffer = new_q.contiguous() + #if qkv_format == "sbhd": + # self.max_seqlen_q = new_q.shape[0] + # if self.allow_query_conversion: + # q_buffer = new_q.transpose(0, 1).contiguous() + #if qkv_format == "thd": + # self.max_seqlen_q = self.max_ctx_len + # if self.allow_query_conversion: + # q_buffer = self.q_buffer[layer_number] + # tex.convert_thd_to_bshd( + # new_q, + # self.q_buffer[layer_number], + # self.cu_seqlens_q, + # self.max_batch_size, + # self.max_ctx_len, + # self.num_heads_q, + # self.head_dim_q, + # ) + # self.q_orig[layer_number] = new_q k_cache, v_cache, page_table = self.cache_manager.step( layer_number, @@ -468,13 +447,13 @@ def step( ) return ( - q_buffer, + #q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, - self.max_seqlen_q, + #self.max_seqlen_q, self.max_seqlen_kv, self.output_qkv_format, ) @@ 
-487,21 +466,22 @@ def post_step( """ Process the attention output in order to return it to the original qkv_format. """ + print('post step ',self.input_qkv_format) if self.input_qkv_format == "bshd": output = output[: self.batch_size, : self.max_seqlen_q].contiguous() - if self.input_qkv_format == "sbhd" and self.allow_query_conversion: + if self.input_qkv_format == "sbhd": # and self.allow_query_conversion: output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() - if self.input_qkv_format == "thd" and self.allow_query_conversion: + if self.input_qkv_format == "thd": # and self.allow_query_conversion: output_buffer = self.q_orig[layer_number] - tex.reshape_o( + tex.convert_bshd_to_thd( output, output_buffer, self.cu_seqlens_q, - self.num_heads_q, - self.head_dim_q, self.batch_size, self.max_ctx_len, - self.is_output_right_aligned, + self.num_heads_q, + self.head_dim_q, + self.total_tokens, ) output = output_buffer.view(output_buffer.shape[0], -1) @@ -585,6 +565,7 @@ def pre_step( ) ).to(dtype=torch.int32, device="cpu") ) + print('self.batch_indices', self.batch_indices) # Advance unfinished sequences for i in unfinished_seqs: @@ -648,6 +629,7 @@ def step( batch_size = new_k.shape[1] ctx_len = new_k.shape[0] + #print('non-paged self.batch_indices', self.batch_indices) tex.copy_to_kv_cache( new_k, new_v, From 0cbe998420e6e2722341f65868c178d0389a4874 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 28 Feb 2025 18:43:46 -0800 Subject: [PATCH 136/239] WIP: fix unfused for separate q/kv format Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- transformer_engine/pytorch/attention.py | 150 ++++++++++++++++---- transformer_engine/pytorch/inference.py | 11 +- 3 files changed, 127 insertions(+), 36 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index f33c92ae2a..3102f0dbdb 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ 
b/tests/pytorch/fused_attn/test_paged_attn.py @@ -354,7 +354,7 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FlashAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["UnfusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["DotProductAttention"])#TransformerLayer"]) # , "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 00c84b954a..676de3b1ec 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -593,9 +593,10 @@ def get_attention_backend( use_fused_attention = False # Filter: QKV layout - qkv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + #qkv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + #) if qkv_format == "thd": if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") @@ -5270,26 +5271,63 @@ def forward( qkv_layout in QKVLayouts ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" 
- qkv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) + #qkv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + #) + # get q_format and kv_format for training and inference + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + #if inference_params is not None: #"_2" in qkv_layout: + # #qkv_format = qkv_layout.replace("paged_kv_", "") + # #q_format, kv_format = qkv_format.split("_2") + # q_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + # ) + # kv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] + # ) + # qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format + #else: + # qkv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + # ) + # q_format = qkv_format + # kv_format = qkv_format + if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged( - self.layer_number, inference_params.input_qkv_format - ) + self.layer_number) #, inference_params.input_qkv_format if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + if qkv_format == "sbhd_2bshd": + key_layer, value_layer = [ + x.transpose(0, 1) for x in [key_layer, value_layer] + ] + + total_tokens, batch_size = None, None + if qkv_format == "thd_2bshd": + total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0] + query_layer = tex.convert_thd_to_bshd( + query_layer, + cu_seqlens_q, + batch_size, + inference_params.max_ctx_len, + query_layer.shape[-2], + query_layer.shape[-1], + ) + query_layer, key_layer, value_layer = [ + x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] + ] 
batch_size, max_seqlen_q, max_seqlen_kv = ( query_layer.shape[1], query_layer.shape[0], key_layer.shape[0], ) - if "padding" in attn_mask_type and qkv_format in ["bshd", "sbhd"]: + if "padding" in attn_mask_type: # and qkv_format in ["bshd", "sbhd"]: attention_mask = get_attn_mask( batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv ) @@ -5422,20 +5460,38 @@ def forward( # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) - if qkv_format == "sbhd": + if q_format == "sbhd": # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] context_layer = context_layer.view(seqlen, batch_size, -1) - if qkv_format == "bshd": + if q_format == "bshd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [b, sq, hp] context_layer = context_layer.view(batch_size, seqlen, -1) + if qkv_format == "thd_2bshd": + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [tq, np, hn] + context_layer = tex.convert_bshd_to_thd( + context_layer, + cu_seqlens_q, + batch_size, + inference_params.max_ctx_len, + context_layer.shape[-2], + context_layer.shape[-1], + total_tokens, + ) + + # [tq, np, hn] --> [tq, hp] + context_layer = context_layer.view(total_tokens, -1) + return context_layer @@ -5473,6 +5529,45 @@ def backward( dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) return dq, dk, dv +def get_qkv_format( + qkv_layout: str = "bshd_bshd_bshd", + inference_params: InferenceParams = None, +) -> str: + """Get qkv layout. + + Parameters + ---------- + qkv_layout: str + Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details. + inference_params: InferenceParams, default = `None` + InferenceParams related to KV caching. 
+ + Returns + ---------- + qkv_format: str, default = `sbhd` + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. + q_format: str + Format of the query tensor, {`bshd`, `sbhd`, `thd`}. + kv_format: str + Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. + """ + if inference_params is not None: #"_2" in qkv_layout: + #qkv_format = qkv_layout.replace("paged_kv_", "") + #q_format, kv_format = qkv_format.split("_2") + q_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) + kv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] + ) + qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format + else: + qkv_format = "".join( + [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + ) + q_format = qkv_format + kv_format = qkv_format + return qkv_format, q_format, kv_format def get_qkv_layout( q: torch.Tensor, @@ -5769,22 +5864,23 @@ def forward( context_parallel = cp_size > 1 # get q_format and kv_format for training and inference - if inference_params is not None: #"_2" in qkv_layout: - #qkv_format = qkv_layout.replace("paged_kv_", "") - #q_format, kv_format = qkv_format.split("_2") - q_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) - kv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] - ) - qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format - else: - qkv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) - q_format = qkv_format - kv_format = qkv_format + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + #if inference_params is not None: #"_2" in qkv_layout: + # #qkv_format = qkv_layout.replace("paged_kv_", "") + # #q_format, kv_format = qkv_format.split("_2") + # q_format = "".join( + # [i for i 
in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + # ) + # kv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] + # ) + # qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format + #else: + # qkv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + # ) + # q_format = qkv_format + # kv_format = qkv_format print('FA 0', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout) # convert q, k, v to bshd if they are in sbhd diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index ac0d83b18d..30f66b6dc0 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -323,7 +323,7 @@ def get_seqlens_pre_step(self): """Get cached sequence lengths for current iteration before adding step_dict.values""" return self.sequences_pre - def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): + def convert_paged_to_nonpaged(self, layer_number: int): #, qkv_format: str): """ Convert k_cache and v_cache from paged to non-paged format. This is used by the UnfusedDotProductAttention backend. 
Both k_cache and v_cache are assumed to be @@ -333,8 +333,6 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): ---------- layer_number: int Layer number of attention in the model - qkv_format: str - Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Returns ------- @@ -358,11 +356,8 @@ def convert_paged_to_nonpaged(self, layer_number: int, qkv_format: str): b=batch_size, ) - new_k_cache = new_k_cache.contiguous() - new_v_cache = new_v_cache.contiguous() - if qkv_format != "thd": - new_k_cache = new_k_cache[:actual_batch_size] - new_v_cache = new_v_cache[:actual_batch_size] + new_k_cache = new_k_cache.contiguous()[:actual_batch_size] + new_v_cache = new_v_cache.contiguous()[:actual_batch_size] return new_k_cache, new_v_cache From 6091196d889edbbf36f362ab1ea958863c74463e Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 28 Feb 2025 18:56:26 -0800 Subject: [PATCH 137/239] WIP: fix fused for separate q/kv formats Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- transformer_engine/pytorch/attention.py | 103 +++++++++++--------- 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 3102f0dbdb..290d5864d8 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -354,7 +354,7 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["UnfusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["DotProductAttention"])#TransformerLayer"]) # , "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, 
True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 676de3b1ec..48a843bc77 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6811,6 +6811,7 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + inference_params: Optional[InferenceParams] = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -6835,58 +6836,62 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) + #qkv_format = "".join( + # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] + #) + # get q_format and kv_format for training and inference + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) - if qkv_format in ["sbhd", "bshd"]: - if qkv_format == "sbhd": - batch_size = query_layer.shape[1] - max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv - if qkv_format == "bshd": - batch_size = query_layer.shape[0] - max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv - max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size - if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported with context parallelism!" - - if cu_seqlens_q is None or cu_seqlens_kv is None: - if attention_mask is None: - raise RuntimeError( - "Please provide attention_mask or cu_seqlens for padding!" 
+ if inference_params is None: + if qkv_format in ["sbhd", "bshd"]: + if qkv_format == "sbhd": + batch_size = query_layer.shape[1] + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + if qkv_format == "bshd": + batch_size = query_layer.shape[0] + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size + if "padding" in attn_mask_type: + assert not context_parallel, "Padding mask not supported with context parallelism!" + + if cu_seqlens_q is None or cu_seqlens_kv is None: + if attention_mask is None: + raise RuntimeError( + "Please provide attention_mask or cu_seqlens for padding!" + ) + if self.attention_type == "self": + cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_kv = cu_seqlens_q + else: + cu_seqlens_q = get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + else: + if cu_seqlens_q is None: + cu_seqlens_q = _get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, ) - if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) - cu_seqlens_kv = cu_seqlens_q - else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) - else: - if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( - batch_size, - max_seqlen_q, - query_layer.device, - ) - if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( - batch_size, - max_seqlen_kv, - key_layer.device, - ) - if qkv_format == "thd": - assert ( - max_seqlen_q is not None - and max_seqlen_kv is not None - and cu_seqlens_q is not None - and cu_seqlens_kv is not None - ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" 
+ if cu_seqlens_kv is None: + cu_seqlens_kv = _get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + if qkv_format == "thd": + assert ( + max_seqlen_q is not None + and max_seqlen_kv is not None + and cu_seqlens_q is not None + and cu_seqlens_kv is not None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" - if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None): + if (q_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_q_padded is None: cu_seqlens_q_padded = cu_seqlens_q + if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: cu_seqlens_kv_padded = cu_seqlens_kv qkv_dtype = TE_DType[query_layer.dtype] @@ -7977,6 +7982,7 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + inference_params=inference_params, ) output = self.fused_attention( query_layer, @@ -8005,6 +8011,7 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + inference_params=inference_params, ) from .cpu_offload import CPUOffloadEnabled From 8585bc3087c7fd754a87b5e0329c00646030819f Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 28 Feb 2025 20:03:26 -0800 Subject: [PATCH 138/239] WIP: flash attn + TELayer + 2 layers Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 131 +++++++++---------- transformer_engine/pytorch/attention.py | 4 +- transformer_engine/pytorch/csrc/kv_cache.cuh | 10 +- transformer_engine/pytorch/inference.py | 41 +++--- 4 files changed, 90 insertions(+), 96 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 290d5864d8..40082c4a2c 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -308,58 +308,59 @@ def generate_args( shapes.append([*shape, 
config.num_gqa_groups, config.head_dim_qk]) shapes.append([*shape, config.num_gqa_groups, config.head_dim_v]) - func = torch.ones if warmup else torch.randn - scale = 1 if warmup else 0.1 num_tensors = len(shapes) - return [ - scale - * func( - *shapes[i], - device="cuda", - dtype=dtype, - ) - for i in range(num_tensors) - ] + if warmup: + return [ + torch.ones( + *shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] + elif module == "TransformerLayer": + return [ + 0.01 * torch.randint( + -100, 100, + shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] + elif module == "DotProductAttention": + return [ + 0.1 * torch.randn( + *shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] def get_tols(module, backend, dtype): if module == "TransformerLayer": - if backend != "FlashAttention": - tols = { - torch.float32: 1e-3, - torch.half: 3e-3, - torch.bfloat16: 1e-2, - } - else: - tols = { - torch.float32: 1e-3, - torch.half: 3e-3, - torch.bfloat16: 1e-2, - } + tols = { + torch.half: 4e-3, + torch.bfloat16: 3e-2, + } if module == "DotProductAttention": - if backend != "FlashAttention": - tols = { - torch.float32: 1e-3, - torch.half: 1e-3, - torch.bfloat16: 1e-2, - } - else: - tols = { - torch.float32: 1e-3, - torch.half: 1e-3, - torch.bfloat16: 1e-2, - } + tols = { + torch.half: 1e-3, + torch.bfloat16: 1e-2, + } return tols[dtype] @pytest.mark.parametrize("dtype", [torch.float16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("module", ["DotProductAttention"])#TransformerLayer"]) # , "DotProductAttention"]) +@pytest.mark.parametrize("backend", ["FlashAttention"]) # , "FlashAttention", "UnfusedAttention"]) 
+@pytest.mark.parametrize("module", ["TransformerLayer"]) # , "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): logger = logging.getLogger("test_paged_attn") - num_layers = 1 # 2 + num_layers = 2 if module == "TransformerLayer" else 1 config = model_configs_infer[model] # figure out supported backends @@ -412,8 +413,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if module == "TransformerLayer": full_output = full_inputs for m in model: + print('xxxxxxxxxxxxxxxxxxxxxxxx ', type(full_output)) full_output = m( - *full_output if isinstance(full_output, List) else full_output, + full_output[0] if isinstance(full_output, List) else full_output, # rotary_pos_emb=rotary_freqs, ) print("full", full_output[0,:2,:8]) @@ -490,13 +492,13 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - model = make_graphed_callables( - model[0], + model = [make_graphed_callables( + model[i], sample_args, num_warmup_iters=10, fp8_enabled=False, sample_kwargs=sample_kwargs, - ) + ) for i in range(num_layers)] sim.reset() inference_params.reset() @@ -586,36 +588,21 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda inference_params.pre_step(step_dict) if inference_params.is_paged: inference_params.cache_manager.print_cache() - if not is_cuda_graph: - incremental_output = incremental_inputs - for m in model: - incremental_output = m( - *( - incremental_output - if isinstance(incremental_output, List) - else incremental_output - ), - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - inference_params=inference_params, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - ) - else: - incremental_output = incremental_inputs - for _ in 
range(num_layers): - incremental_output = model( - *( - incremental_output - if isinstance(incremental_output, List) - else incremental_output - ), - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - inference_params=inference_params, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - ) + incremental_output = incremental_inputs + for m in model: + print('xxxxdgdg ', type(incremental_output)) + incremental_output = m( + *incremental_output + if isinstance(incremental_output, List) + else incremental_output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + inference_params=inference_params, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + ) + incremental_output = [incremental_output] + incremental_output = incremental_output[0] # compare results tol = get_tols(module, backend, dtype) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 48a843bc77..9fe4d6fd97 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6092,7 +6092,7 @@ def forward( fa_optional_forward_kwargs["block_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged - else inference_params.cache_manager.batch_indices.unsqueeze(1)[:batch_size] + else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[:batch_size] ) if _use_flash_attn_3: fa_3_optional_forward_kwargs = {} @@ -6102,7 +6102,7 @@ def forward( fa_3_optional_forward_kwargs["page_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged - else inference_params.cache_manager.batch_indices + else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[:batch_size] ) if fp8: QKV_quantizer = quantizers["scaling_fwd"][META_QKV] diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh index 0cdf043d1e..4a693de3f7 100644 --- 
a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -68,11 +68,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in } } } - if (blockIdx.x == 0) { - for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { - batch_indices[batch_idx] = batch_idx; - } - } +// if (blockIdx.x == 0) { +// for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { +// batch_indices[batch_idx] = batch_idx; +// } +// } } template diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 30f66b6dc0..1cbe6cbf0b 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -462,23 +462,23 @@ def post_step( Process the attention output in order to return it to the original qkv_format. """ print('post step ',self.input_qkv_format) - if self.input_qkv_format == "bshd": - output = output[: self.batch_size, : self.max_seqlen_q].contiguous() - if self.input_qkv_format == "sbhd": # and self.allow_query_conversion: - output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() - if self.input_qkv_format == "thd": # and self.allow_query_conversion: - output_buffer = self.q_orig[layer_number] - tex.convert_bshd_to_thd( - output, - output_buffer, - self.cu_seqlens_q, - self.batch_size, - self.max_ctx_len, - self.num_heads_q, - self.head_dim_q, - self.total_tokens, - ) - output = output_buffer.view(output_buffer.shape[0], -1) + #if self.input_qkv_format == "bshd": + # output = output[: self.batch_size, : self.max_seqlen_q].contiguous() + #if self.input_qkv_format == "sbhd": # and self.allow_query_conversion: + # output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() + #if self.input_qkv_format == "thd": # and self.allow_query_conversion: + # output_buffer = self.q_orig[layer_number] + # tex.convert_bshd_to_thd( + # output, + # output_buffer, + # self.cu_seqlens_q, + # 
self.batch_size, + # self.max_ctx_len, + # self.num_heads_q, + # self.head_dim_q, + # self.total_tokens, + # ) + # output = output_buffer.view(output_buffer.shape[0], -1) return output @@ -510,6 +510,7 @@ def __init__( self.cache = {} # track sequence indices in the batch in order to re-index k_cache and v_cache self.batch_indices = None + self.batch_indices_post = None def allocate_memory(self, layer_number): """Allocate memory for the cache""" @@ -536,6 +537,12 @@ def allocate_memory(self, layer_number): dtype=torch.int32, device=torch.cuda.current_device(), ) + self.batch_indices_post = torch.range( + 0, + self.max_batch_size-1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) def pre_step( self, From e05ba5347bd0b760b5180034b3259a0d9f252aa6 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 28 Feb 2025 20:52:10 -0800 Subject: [PATCH 139/239] WIP: unfused + TL + 2layers Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 6 +++--- transformer_engine/pytorch/inference.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 40082c4a2c..a5228e1477 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -226,7 +226,7 @@ def get_model( attn_mask_type = "causal" qkv_format = "bshd" if mode == "inference": - attn_mask_type = "padding_causal" if backend == "FlashAttention" else "padding" + attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding" if module == "TransformerLayer": hidden_size = config.head_dim_qk * config.num_heads @@ -351,11 +351,11 @@ def get_tols(module, backend, dtype): } return tols[dtype] -@pytest.mark.parametrize("dtype", [torch.float16]) # param_types) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) # param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) 
@pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FlashAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["UnfusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["TransformerLayer"]) # , "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 1cbe6cbf0b..be3743bfe1 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -356,8 +356,8 @@ def convert_paged_to_nonpaged(self, layer_number: int): #, qkv_format: str): b=batch_size, ) - new_k_cache = new_k_cache.contiguous()[:actual_batch_size] - new_v_cache = new_v_cache.contiguous()[:actual_batch_size] + new_k_cache = new_k_cache[:actual_batch_size].contiguous() + new_v_cache = new_v_cache[:actual_batch_size].contiguous() return new_k_cache, new_v_cache @@ -724,7 +724,7 @@ def reset(self): def allocate_memory(self, layer_number): """Allocate memory for the cache""" - k_cache = torch.empty( + k_cache = torch.zeros( self.total_num_pages, self.page_size, self.num_heads, @@ -732,7 +732,7 @@ def allocate_memory(self, layer_number): dtype=self.dtype, device=torch.cuda.current_device(), ) - v_cache = torch.empty( + v_cache = torch.zeros( self.total_num_pages, self.page_size, self.num_heads, From 490e57a5efd878e14609f9b2c5e63b5121ba1824 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 28 Feb 2025 21:05:17 -0800 Subject: [PATCH 140/239] WIP: all modules/backend Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py 
b/tests/pytorch/fused_attn/test_paged_attn.py index a5228e1477..c9fe93185a 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -351,16 +351,16 @@ def get_tols(module, backend, dtype): } return tols[dtype] -@pytest.mark.parametrize("dtype", [torch.bfloat16]) # param_types) +@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["UnfusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) -@pytest.mark.parametrize("module", ["TransformerLayer"]) # , "DotProductAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): logger = logging.getLogger("test_paged_attn") - num_layers = 2 if module == "TransformerLayer" else 1 + num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 config = model_configs_infer[model] # figure out supported backends From 339bfa9de127812abec6600992ddef241b729fc4 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sat, 1 Mar 2025 10:30:54 -0800 Subject: [PATCH 141/239] WIP: minor cleanup Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 7 ++++ transformer_engine/pytorch/attention.py | 44 +-------------------- 2 files changed, 8 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index c9fe93185a..9f67a953b4 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -388,7 +388,10 @@ def test_paged_attn(dtype, model, qkv_format, 
is_paged, backend, module, is_cuda os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") + # flash-attn requires page size >= 256 if backend == "FlashAttention": + config_max_seqlen_q = config.max_seqlen_q + config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 config.max_seqlen_kv = 256 @@ -654,3 +657,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.serving_times = sim.arrival_times + sim.request_delays sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) + + if backend == "FlashAttention": + config.max_seqlen_q = config_max_seqlen_q + config.max_seqlen_kv = config_max_seqlen_kv diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9fe4d6fd97..4e7ba96deb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -594,9 +594,6 @@ def get_attention_backend( # Filter: QKV layout qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) - #qkv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - #) if qkv_format == "thd": if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") @@ -5271,27 +5268,8 @@ def forward( qkv_layout in QKVLayouts ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" 
- #qkv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - #) # get q_format and kv_format for training and inference qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) - #if inference_params is not None: #"_2" in qkv_layout: - # #qkv_format = qkv_layout.replace("paged_kv_", "") - # #q_format, kv_format = qkv_format.split("_2") - # q_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - # ) - # kv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] - # ) - # qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format - #else: - # qkv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - # ) - # q_format = qkv_format - # kv_format = qkv_format if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged( @@ -5865,24 +5843,7 @@ def forward( # get q_format and kv_format for training and inference qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) - #if inference_params is not None: #"_2" in qkv_layout: - # #qkv_format = qkv_layout.replace("paged_kv_", "") - # #q_format, kv_format = qkv_format.split("_2") - # q_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - # ) - # kv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] - # ) - # qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format - #else: - # qkv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - # ) - # q_format = qkv_format - # kv_format = qkv_format - - print('FA 0', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout) + # convert q, k, v to bshd if they are in sbhd # 
qkv_format is unchanged if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): @@ -6836,9 +6797,6 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - #qkv_format = "".join( - # [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - #) # get q_format and kv_format for training and inference qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) From 8d13f129e0b9a2467e7a986e6486f10722f11a8c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Mar 2025 18:31:30 +0000 Subject: [PATCH 142/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_paged_attn.py | 80 +++++++----- transformer_engine/pytorch/attention.py | 116 +++++++++++------- transformer_engine/pytorch/csrc/extensions.h | 6 +- .../pytorch/csrc/extensions/attention.cu | 35 +++--- transformer_engine/pytorch/csrc/kv_cache.cuh | 15 +-- transformer_engine/pytorch/inference.py | 64 +++++----- 6 files changed, 182 insertions(+), 134 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 9f67a953b4..677b01b619 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -90,8 +90,8 @@ def __init__( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") - #self.context_lens[0] = 2 - #self.context_lens[2] = 3 + # self.context_lens[0] = 2 + # self.context_lens[2] = 3 # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -110,7 +110,7 @@ def __init__( self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to( dtype=torch.int32, device="cpu" ) - #self.arrival_times[2] = 0 + # 
self.arrival_times[2] = 0 # self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() @@ -208,6 +208,7 @@ def step(self, dynamic_fill: bool = True): self.t_batch_size = len(self.t_seq_ids) self.t_total_lens = self.t_ctx_lens + self.t_gen_lens + def get_model( module: torch.nn.Module, config: ModelConfig, @@ -216,7 +217,7 @@ def get_model( qkv_format: str = "bshd", num_layers: int = 1, mode: str = "reference", - ): +): reset_rng_states() sigma = 0.023 init_method = init_method_normal(sigma) @@ -268,13 +269,14 @@ def get_model( ] return model + def generate_args( module: torch.nn.Module, config: ModelConfig, dtype: torch.dtype, qkv_format: str = "bshd", mode: str = "full_inputs", - ): +): if mode == "full_inputs": warmup = False shapes = [] @@ -287,10 +289,20 @@ def generate_args( [config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk] ) shapes.append( - [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk] + [ + config.total_requests, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim_qk, + ] ) shapes.append( - [config.total_requests, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_v] + [ + config.total_requests, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim_v, + ] ) elif mode == "sample_args": warmup = True @@ -320,8 +332,10 @@ def generate_args( ] elif module == "TransformerLayer": return [ - 0.01 * torch.randint( - -100, 100, + 0.01 + * torch.randint( + -100, + 100, shapes[i], device="cuda", dtype=dtype, @@ -330,7 +344,8 @@ def generate_args( ] elif module == "DotProductAttention": return [ - 0.1 * torch.randn( + 0.1 + * torch.randn( *shapes[i], device="cuda", dtype=dtype, @@ -338,6 +353,7 @@ def generate_args( for i in range(num_tensors) ] + def get_tols(module, backend, dtype): if module == "TransformerLayer": tols = { @@ -351,6 +367,7 @@ def get_tols(module, backend, dtype): 
} return tols[dtype] + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @@ -416,14 +433,14 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if module == "TransformerLayer": full_output = full_inputs for m in model: - print('xxxxxxxxxxxxxxxxxxxxxxxx ', type(full_output)) + print("xxxxxxxxxxxxxxxxxxxxxxxx ", type(full_output)) full_output = m( full_output[0] if isinstance(full_output, List) else full_output, # rotary_pos_emb=rotary_freqs, ) - print("full", full_output[0,:2,:8]) - print("full", full_output[1,:7,:8]) - print("full", full_output[2,:3,:8]) + print("full", full_output[0, :2, :8]) + print("full", full_output[1, :7, :8]) + print("full", full_output[2, :3, :8]) # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -460,7 +477,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda head_dim_q=config.head_dim_qk, max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, - #allow_query_conversion=backend != "FusedAttention", + # allow_query_conversion=backend != "FusedAttention", ) for layer_number in range(1, num_layers + 1): inference_params.allocate_memory(layer_number, qkv_format) @@ -475,7 +492,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist())) inference_params.pre_step(step_dict) - sample_args = generate_args(module, config, dtype, qkv_format=qkv_format, mode="sample_args") + sample_args = generate_args( + module, config, dtype, qkv_format=qkv_format, mode="sample_args" + ) sample_kwargs = {} sample_kwargs["cu_seqlens_q"] = torch.linspace( 0, @@ -495,13 +514,16 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - 
model = [make_graphed_callables( - model[i], - sample_args, - num_warmup_iters=10, - fp8_enabled=False, - sample_kwargs=sample_kwargs, - ) for i in range(num_layers)] + model = [ + make_graphed_callables( + model[i], + sample_args, + num_warmup_iters=10, + fp8_enabled=False, + sample_kwargs=sample_kwargs, + ) + for i in range(num_layers) + ] sim.reset() inference_params.reset() @@ -593,11 +615,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda inference_params.cache_manager.print_cache() incremental_output = incremental_inputs for m in model: - print('xxxxdgdg ', type(incremental_output)) + print("xxxxdgdg ", type(incremental_output)) incremental_output = m( - *incremental_output - if isinstance(incremental_output, List) - else incremental_output, + *incremental_output if isinstance(incremental_output, List) else incremental_output, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, @@ -610,13 +630,13 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # compare results tol = get_tols(module, backend, dtype) for i, seq in enumerate(sim.t_seq_ids): - #token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 + # token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": print(i, seq, sim.t_total_lens, sim.step_lens, token_index) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(incremental_output[i, token_index, :4]) - #print(incremental_output[i, sim.step_lens[i] - 1, :4]) + # print(incremental_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], # incremental_output[:sim.step_lens[i] - 1, i, :], @@ -651,7 +671,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ) sim.t += 1 sim.t_gen_lens = sim.t_gen_lens 
+ 1 - #if sim.t == 1: + # if sim.t == 1: # break sim.serving_times = sim.arrival_times + sim.request_delays diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4e7ba96deb..d64655be84 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5273,7 +5273,8 @@ def forward( if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged( - self.layer_number) #, inference_params.input_qkv_format + self.layer_number + ) # , inference_params.input_qkv_format if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now @@ -5281,9 +5282,7 @@ def forward( x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] if qkv_format == "sbhd_2bshd": - key_layer, value_layer = [ - x.transpose(0, 1) for x in [key_layer, value_layer] - ] + key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]] total_tokens, batch_size = None, None if qkv_format == "thd_2bshd": @@ -5295,7 +5294,7 @@ def forward( inference_params.max_ctx_len, query_layer.shape[-2], query_layer.shape[-1], - ) + ) query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] @@ -5305,7 +5304,7 @@ def forward( key_layer.shape[0], ) - if "padding" in attn_mask_type: # and qkv_format in ["bshd", "sbhd"]: + if "padding" in attn_mask_type: # and qkv_format in ["bshd", "sbhd"]: attention_mask = get_attn_mask( batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv ) @@ -5465,7 +5464,7 @@ def forward( context_layer.shape[-2], context_layer.shape[-1], total_tokens, - ) + ) # [tq, np, hn] --> [tq, hp] context_layer = context_layer.view(total_tokens, -1) @@ -5507,6 +5506,7 @@ def backward( dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) return dq, dk, dv + def get_qkv_format( qkv_layout: str = "bshd_bshd_bshd", inference_params: InferenceParams = None, @@ -5529,9 
+5529,9 @@ def get_qkv_format( kv_format: str Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. """ - if inference_params is not None: #"_2" in qkv_layout: - #qkv_format = qkv_layout.replace("paged_kv_", "") - #q_format, kv_format = qkv_format.split("_2") + if inference_params is not None: # "_2" in qkv_layout: + # qkv_format = qkv_layout.replace("paged_kv_", "") + # q_format, kv_format = qkv_format.split("_2") q_format = "".join( [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] ) @@ -5547,6 +5547,7 @@ def get_qkv_format( kv_format = qkv_format return qkv_format, q_format, kv_format + def get_qkv_layout( q: torch.Tensor, k: torch.Tensor, @@ -5859,7 +5860,8 @@ def forward( ) else: query_layer, key_layer, value_layer = [ - x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer) + x.transpose(0, 1).contiguous() + for x in (query_layer, key_layer, value_layer) ] elif q_format == "sbhd" and kv_format == "bshd": query_layer = query_layer.transpose(0, 1).contiguous() @@ -5879,13 +5881,17 @@ def forward( ] elif q_format == "sbhd" and kv_format == "bshd": query_layer._data = query_layer._data.transpose(0, 1).contiguous() - query_layer = Float8Tensor.make_like(query_layer, data=query_layer._data, shape=query_layer._data.shape) + query_layer = Float8Tensor.make_like( + query_layer, data=query_layer._data, shape=query_layer._data.shape + ) if context_parallel: query_layer._data, key_layer._data, value_layer._data = [ x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - print('FA 1', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout) + print( + "FA 1", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout + ) # get accurate batch_size, max_seqlen and cu_seqlens batch_size = None if inference_params is None: @@ -5896,7 +5902,9 @@ def forward( max_seqlen_kv *= cp_size if "padding" in attn_mask_type: - assert not context_parallel, 
"Padding mask not supported with context parallelism!" + assert ( + not context_parallel + ), "Padding mask not supported with context parallelism!" # [b * s, h, d] query_layer, key_layer, value_layer = [ @@ -5975,8 +5983,10 @@ def forward( num_heads, head_dim, batch_size * context_len, - ) - query_layer = Float8Tensor.make_like(query_layer, data=query_layer._data, shape=query_layer._data.shape) + ) + query_layer = Float8Tensor.make_like( + query_layer, data=query_layer._data, shape=query_layer._data.shape + ) else: query_layer = tex.convert_bshd_to_thd( query_layer, @@ -5986,7 +5996,7 @@ def forward( num_heads, head_dim, batch_size * context_len, - ) + ) if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] @@ -6037,7 +6047,11 @@ def forward( if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] - if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type and inference_params is None: + if ( + qkv_format in ["bshd", "sbhd"] + and "padding" not in attn_mask_type + and inference_params is None + ): func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: func = ( @@ -6053,7 +6067,9 @@ def forward( fa_optional_forward_kwargs["block_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged - else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[:batch_size] + else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[ + :batch_size + ] ) if _use_flash_attn_3: fa_3_optional_forward_kwargs = {} @@ -6063,7 +6079,9 @@ def forward( fa_3_optional_forward_kwargs["page_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged - else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[:batch_size] + else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[ + :batch_size + ] ) if fp8: QKV_quantizer = 
quantizers["scaling_fwd"][META_QKV] @@ -6092,8 +6110,8 @@ def convert_to_torch_float8(tensor, dtype): fa_3_optional_forward_kwargs["descale_q"] = ( query_layer._scale_inv.unsqueeze(0) ) - fa_3_optional_forward_kwargs["descale_k"] = ( - key_layer._scale_inv.unsqueeze(0) + fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( + 0 ) fa_3_optional_forward_kwargs["descale_v"] = ( value_layer._scale_inv.unsqueeze(0) @@ -6116,8 +6134,7 @@ def convert_to_torch_float8(tensor, dtype): if _flash_attn_3_0_0_beta: e.args = ( e.args[0] - + ". Please update your flash-attn v3 (beta) installation" - " as it " + + ". Please update your flash-attn v3 (beta) installation as it " + "may have added more supported arguments to its API. \n" + _flash_attn_3_installation_steps, ) + e.args[1:] @@ -6141,14 +6158,11 @@ def convert_to_torch_float8(tensor, dtype): ) if inference_params is None: - if ( - qkv_format in ["sbhd", "bshd"] - and "padding" in attn_mask_type - ): + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) elif qkv_format in ["bshd", "sbhd_2bshd"]: # convert back to bshd_2bshd from thd_2bshd - #batch_size, context_len, num_heads, head_dim = output.shape + # batch_size, context_len, num_heads, head_dim = output.shape if isinstance(query_layer, Float8Tensor): output._data = tex.convert_thd_to_bshd( output._data, @@ -6157,7 +6171,7 @@ def convert_to_torch_float8(tensor, dtype): context_len, num_heads, head_dim, - ) + ) output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape) else: output = tex.convert_thd_to_bshd( @@ -6167,7 +6181,7 @@ def convert_to_torch_float8(tensor, dtype): context_len, num_heads, head_dim, - ) + ) if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) @@ -6813,7 +6827,9 @@ def forward( max_seqlen_q *= cp_size max_seqlen_kv *= cp_size if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported 
with context parallelism!" + assert ( + not context_parallel + ), "Padding mask not supported with context parallelism!" if cu_seqlens_q is None or cu_seqlens_kv is None: if attention_mask is None: @@ -7557,14 +7573,12 @@ def forward( head_dim_v == self.hidden_size_per_attention_head_v ), f"Values have head_dim = {head_dim_v}, " "but expected head_dim = {self.hidden_size_per_attention_head_v}!" - assert ( - num_gqa_groups == self.num_gqa_groups_per_partition - ), ( + assert num_gqa_groups == self.num_gqa_groups_per_partition, ( "Keys and values must have num_gqa_group =" f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}." ) - # checks for attention mask + # checks for attention mask if attn_mask_type is None: attn_mask_type = self.attn_mask_type else: @@ -7653,18 +7667,18 @@ def forward( # update KV cache and retrieve full KV tokens ( - #query_layer, + # query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, - #max_seqlen_q, + # max_seqlen_q, max_seqlen_kv, qkv_format, ) = inference_params.step( self.layer_number, - #query_layer, + # query_layer, key_layer, value_layer, qkv_format, @@ -7672,7 +7686,9 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None - print('FA DPA', [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format)#, qkv_layout) + print( + "FA DPA", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format + ) # , qkv_layout) # get accurate qkv_layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( @@ -7683,7 +7699,11 @@ def forward( q_format, kv_format, ) = get_qkv_layout( - query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format, inference_params=inference_params, + query_layer._data, + key_layer._data, + value_layer._data, + qkv_format=qkv_format, + inference_params=inference_params, ) else: ( @@ -7694,7 +7714,11 @@ def forward( q_format, kv_format, ) = get_qkv_layout( - query_layer, key_layer, value_layer, 
qkv_format=qkv_format, inference_params=inference_params, + query_layer, + key_layer, + value_layer, + qkv_format=qkv_format, + inference_params=inference_params, ) # adjust max_seqlen and cu_seqlens @@ -8014,12 +8038,12 @@ def forward( inference_params=inference_params, ) - #if inference_params is not None: + # if inference_params is not None: ## inference_params.is_output_right_aligned = use_flash_attention # output = inference_params.post_step(self.layer_number, output) - #print(output[0,-2:,:8]) - #print(output[1,:,:8]) - #print(output[2,-3:,:8]) + # print(output[0,-2:,:8]) + # print(output[1,:,:8]) + # print(output[2,-3:,:8]) return output diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index dd069542d9..f73dcf140f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -70,8 +70,10 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d); -at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d, int t); +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, + int h, int d); +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, + int h, int d, int t); void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index b43be4c267..66f27f87c6 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ 
b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1034,7 +1034,8 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t **************************************************************************************************/ template -void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d) { +void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, + int b, int max_seq_len, int h, int d) { transformer_engine::fused_attn:: convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(tensor.data_ptr()), @@ -1042,7 +1043,8 @@ void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at:: b, max_seq_len, h, d); } -at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d) { +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, + int h, int d) { std::vector shape = {b, max_seq_len, h, d}; at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); if (new_tensor.scalar_type() == at::ScalarType::Half) { @@ -1071,16 +1073,17 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, **************************************************************************************************/ template -void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, - at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d) { +void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, + int b, int max_seq_len, int h, int d) { transformer_engine::fused_attn:: convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(tensor.data_ptr()), - reinterpret_cast(new_tensor.data_ptr()), - cu_seqlens.data_ptr(), b, max_seq_len, h, d); + 
reinterpret_cast(new_tensor.data_ptr()), cu_seqlens.data_ptr(), + b, max_seq_len, h, d); } -at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, int h, int d, int t) { +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, + int h, int d, int t) { std::vector shape = {t, h, d}; at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); if (tensor.scalar_type() == at::ScalarType::Half) { @@ -1119,11 +1122,10 @@ at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, template void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, - at::Tensor v_cache, at::Tensor page_table, - at::Tensor cu_new_lens, at::Tensor cu_cached_lens, - NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged) { + at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens, + at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, + int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq, bool is_non_paged) { if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) { if (is_non_paged) { @@ -1145,11 +1147,10 @@ void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_ } } -void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, - at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens, - at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, int d_k, - int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged) { +void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, + at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens, + NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int 
b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged) { NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && new_k.scalar_type() == new_v.scalar_type() && new_k.scalar_type() == k_cache.scalar_type(), diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh index 4a693de3f7..bd585618bb 100644 --- a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -9,7 +9,8 @@ namespace transformer_engine { namespace fused_attn { template -__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, int b, int max_seq_len, int h, int d) { +__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, + int b, int max_seq_len, int h, int d) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d; int thd_offset = cu_seqlens[batch_idx] * h * d; @@ -24,7 +25,7 @@ __global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tenso template __global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, - int b, int max_seq_len, int h, int d) { + int b, int max_seq_len, int h, int d) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]; int num_elts = seqlen * h * d; @@ -68,11 +69,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in } } } -// if (blockIdx.x == 0) { -// for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { -// batch_indices[batch_idx] = batch_idx; -// } -// } + // if (blockIdx.x == 0) { + // for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { + // batch_indices[batch_idx] = batch_idx; + // } + // } } template diff --git a/transformer_engine/pytorch/inference.py 
b/transformer_engine/pytorch/inference.py index be3743bfe1..67c00dc781 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -152,7 +152,7 @@ def __init__( max_ctx_len: int = None, qkv_format: str = "bshd", cache_manager: KVCacheManager = None, - #allow_query_conversion: bool = True, + # allow_query_conversion: bool = True, ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv @@ -164,9 +164,9 @@ def __init__( _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - #self.allow_query_conversion = allow_query_conversion and ( + # self.allow_query_conversion = allow_query_conversion and ( # _NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN - #) + # ) if not self.is_paged: cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager @@ -206,7 +206,7 @@ def __init__( if qkv_format == "thd": assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" self.max_ctx_len = max_ctx_len - #if self.allow_query_conversion: + # if self.allow_query_conversion: # # query is converted to 'bshd' for certain backends # assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" # assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" 
@@ -217,7 +217,7 @@ def __init__( # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: + if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: self.output_qkv_format = self.cache_qkv_format else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format @@ -234,7 +234,7 @@ def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() self.cache_manager.reset() - #if self.input_qkv_format == "thd" and self.allow_query_conversion: + # if self.input_qkv_format == "thd" and self.allow_query_conversion: # for _, q_buffer in self.q_buffer.items(): # q_buffer.fill_(0) @@ -282,15 +282,15 @@ def allocate_memory(self, layer_number: int, qkv_format: str): device=torch.cuda.current_device(), ) -# if qkv_format == "thd" and self.allow_query_conversion: -# self.q_buffer[layer_number] = torch.zeros( -# self.max_batch_size, -# self.max_ctx_len, -# self.num_heads_q, -# self.head_dim_q, -# dtype=self.dtype, -# device=torch.cuda.current_device(), -# ) + # if qkv_format == "thd" and self.allow_query_conversion: + # self.q_buffer[layer_number] = torch.zeros( + # self.max_batch_size, + # self.max_ctx_len, + # self.num_heads_q, + # self.head_dim_q, + # dtype=self.dtype, + # device=torch.cuda.current_device(), + # ) def pre_step( self, @@ -323,7 +323,7 @@ def get_seqlens_pre_step(self): """Get cached sequence lengths for current iteration before adding step_dict.values""" return self.sequences_pre - def convert_paged_to_nonpaged(self, layer_number: int): #, qkv_format: str): + def convert_paged_to_nonpaged(self, layer_number: int): # , qkv_format: str): """ Convert k_cache and v_cache from paged to non-paged format. This is used by the UnfusedDotProductAttention backend. 
Both k_cache and v_cache are assumed to be @@ -364,7 +364,7 @@ def convert_paged_to_nonpaged(self, layer_number: int): #, qkv_format: str): def step( self, layer_number: int, - #new_q: torch.Tensor, + # new_q: torch.Tensor, new_k: torch.Tensor, new_v: torch.Tensor, qkv_format: str, @@ -402,22 +402,22 @@ def step( qkv_format: str Updated qkv_format, e.g. the input 'thd' format may become 'thd_2bshd' after step() """ - print('self.sequences', self.sequences) + print("self.sequences", self.sequences) self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: + if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: self.output_qkv_format = self.cache_qkv_format else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - #q_buffer = new_q - #if qkv_format == "bshd": + # q_buffer = new_q + # if qkv_format == "bshd": # self.max_seqlen_q = new_q.shape[1] # q_buffer = new_q.contiguous() - #if qkv_format == "sbhd": + # if qkv_format == "sbhd": # self.max_seqlen_q = new_q.shape[0] # if self.allow_query_conversion: # q_buffer = new_q.transpose(0, 1).contiguous() - #if qkv_format == "thd": + # if qkv_format == "thd": # self.max_seqlen_q = self.max_ctx_len # if self.allow_query_conversion: # q_buffer = self.q_buffer[layer_number] @@ -442,13 +442,13 @@ def step( ) return ( - #q_buffer, + # q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, - #self.max_seqlen_q, + # self.max_seqlen_q, self.max_seqlen_kv, self.output_qkv_format, ) @@ -461,12 +461,12 @@ def post_step( """ Process the attention output in order to return it to the original qkv_format. 
""" - print('post step ',self.input_qkv_format) - #if self.input_qkv_format == "bshd": + print("post step ", self.input_qkv_format) + # if self.input_qkv_format == "bshd": # output = output[: self.batch_size, : self.max_seqlen_q].contiguous() - #if self.input_qkv_format == "sbhd": # and self.allow_query_conversion: + # if self.input_qkv_format == "sbhd": # and self.allow_query_conversion: # output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() - #if self.input_qkv_format == "thd": # and self.allow_query_conversion: + # if self.input_qkv_format == "thd": # and self.allow_query_conversion: # output_buffer = self.q_orig[layer_number] # tex.convert_bshd_to_thd( # output, @@ -539,7 +539,7 @@ def allocate_memory(self, layer_number): ) self.batch_indices_post = torch.range( 0, - self.max_batch_size-1, + self.max_batch_size - 1, dtype=torch.int32, device=torch.cuda.current_device(), ) @@ -567,7 +567,7 @@ def pre_step( ) ).to(dtype=torch.int32, device="cpu") ) - print('self.batch_indices', self.batch_indices) + print("self.batch_indices", self.batch_indices) # Advance unfinished sequences for i in unfinished_seqs: @@ -631,7 +631,7 @@ def step( batch_size = new_k.shape[1] ctx_len = new_k.shape[0] - #print('non-paged self.batch_indices', self.batch_indices) + # print('non-paged self.batch_indices', self.batch_indices) tex.copy_to_kv_cache( new_k, new_v, From 6bd11424bca954ecc075e4d20b7d16a9fc8e4385 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 1 Mar 2025 11:55:47 -0800 Subject: [PATCH 143/239] WIP: FlashAttention on Hopper with 2.7.3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 12 +++++++----- transformer_engine/pytorch/attention.py | 8 ++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 
677b01b619..fa83088d2f 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -19,6 +19,7 @@ from transformer_engine.pytorch.attention import ( DotProductAttention, InferenceParams, + _flash_attn_3_plus, ) from transformer_engine.pytorch.utils import ( get_device_compute_capability, @@ -358,7 +359,7 @@ def get_tols(module, backend, dtype): if module == "TransformerLayer": tols = { torch.half: 4e-3, - torch.bfloat16: 3e-2, + torch.bfloat16: 3.5e-2, } if module == "DotProductAttention": tols = { @@ -372,7 +373,7 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FlashAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): @@ -406,7 +407,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") # flash-attn requires page size >= 256 - if backend == "FlashAttention": + if backend == "FlashAttention" and not _flash_attn_3_plus: config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 @@ -448,7 +449,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda page_size = None total_num_pages = None if is_paged: - page_size = 256 if backend == "FlashAttention" else 16 + page_size = 256 if backend == "FlashAttention" and not _flash_attn_3_plus else 16 config.max_seqlen_kv = 
round_up(config.max_seqlen_kv, page_size) total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) else: @@ -624,6 +625,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, ) + print('ddddddddddd ', len(incremental_output), type(incremental_output)) incremental_output = [incremental_output] incremental_output = incremental_output[0] @@ -678,6 +680,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention": + if backend == "FlashAttention" and not _flash_attn_3_plus: config.max_seqlen_q = config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d64655be84..5009fcea4f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -130,6 +130,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_5_7_plus = False _flash_attn_2_6_plus = False _flash_attn_2_7_plus = False +_flash_attn_3_plus = False flash_attn_cuda_bwd = None flash_attn_func = None @@ -178,6 +179,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6.0") _flash_attn_2_7_plus = _flash_attn_version >= PkgVersion("2.7.0") + _flash_attn_3_plus = _flash_attn_version >= PkgVersion("3.0") elif ( torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN ): @@ -6064,6 +6066,7 @@ def forward( fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if inference_params is not None: + # use page_table to support thd_2bshd format when is_paged=False fa_optional_forward_kwargs["block_table"] = ( 
inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged @@ -6076,6 +6079,11 @@ def forward( fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["deterministic"] = self.deterministic if inference_params is not None: + # use page_table to support thd_2bshd format when is_paged=False + # 2.7.3+ -> page_table + # git clone --recursive -b v2.7.3 https://github.com/Dao-AILab/flash-attention.git + # MAX_JOBS=6 FLASH_ATTN_CUDA_ARCHS=90 pip install -v -e . + assert _flash_attn_3_plus, "Please install flash-attn from v3" fa_3_optional_forward_kwargs["page_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged From ee006c4cd5148193c62679a085fe39f181e3073c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 1 Mar 2025 16:10:32 -0800 Subject: [PATCH 144/239] WIP: FlashAttention + v3 from 39e7179 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 14 +- tests/pytorch/fused_attn/test_paged_attn.py | 8 +- transformer_engine/pytorch/attention.py | 187 +++++++++++++------- transformer_engine/pytorch/module/linear.py | 1 + 4 files changed, 137 insertions(+), 73 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 775bf1651e..90e77ffb45 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -204,13 +204,13 @@ def test(): return available_backends, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - with logging_context(): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - 
fused_attn_backends.append(fused_attention_backend) + #with logging_context(): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, fused_attn_backends diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index fa83088d2f..cc5670156d 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.attention import ( DotProductAttention, InferenceParams, - _flash_attn_3_plus, + _use_flash_attn_3, ) from transformer_engine.pytorch.utils import ( get_device_compute_capability, @@ -407,7 +407,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") # flash-attn requires page size >= 256 - if backend == "FlashAttention" and not _flash_attn_3_plus: + if backend == "FlashAttention" and not _use_flash_attn_3: config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 @@ -449,7 +449,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda page_size = None total_num_pages = None if is_paged: - page_size = 256 if backend == "FlashAttention" and not _flash_attn_3_plus else 16 + page_size = 256 if backend == "FlashAttention" and not _use_flash_attn_3 else 16 config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) else: @@ -680,6 +680,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = 
sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention" and not _flash_attn_3_plus: + if backend == "FlashAttention" and not _use_flash_attn_3: config.max_seqlen_q = config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5009fcea4f..6032206ece 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -130,7 +130,6 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_5_7_plus = False _flash_attn_2_6_plus = False _flash_attn_2_7_plus = False -_flash_attn_3_plus = False flash_attn_cuda_bwd = None flash_attn_func = None @@ -179,7 +178,6 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6.0") _flash_attn_2_7_plus = _flash_attn_version >= PkgVersion("2.7.0") - _flash_attn_3_plus = _flash_attn_version >= PkgVersion("3.0") elif ( torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN ): @@ -196,22 +194,22 @@ def _get_supported_versions(version_min, version_max): _flash_attn_version, ) -# Detect flash-attn v3 in the environment -# This section will be removed when FA3 is released as a regular FA package, -# i.e. 
flashattn-hopper 3.0.0 as flash-attn 3.0.0 +# Detect flash_attn_3 in the environment _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False _use_flash_attn_3 = False -# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved -# https://github.com/Dao-AILab/flash-attention/issues/1452 _flash_attn_3_installation_steps = """\ -(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" -(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` -(3) mkdir -p $python_path/flashattn_hopper -(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention/ + git checkout 39e7197 + cd hopper/ + python setup.py install + python_path=`python -c "import site; print(site.getsitepackages()[0])"` + mkdir -p $python_path/flash_attn_3 + wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/refs/heads/main/hopper/flash_attn_interface.py""" try: - _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) + _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: fa_logger.debug( @@ -219,26 +217,26 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_installation_steps, ) else: - from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 - from flashattn_hopper.flash_attn_interface import ( + from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_3.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) - from flashattn_hopper.flash_attn_interface import ( + from 
flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn_with_kvcache_v3, ) - from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 - from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, - ) - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, - ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 + from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 + #from flash_attn_3.flash_attn_interface import ( + # _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, + #) + #from flash_attn_3.flash_attn_interface import ( + # _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, + #) _flash_attn_3_is_installed = True _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") _use_flash_attn_3 = True - + _attention_backends = { "attention_params": None, "use_flash_attention": None, @@ -986,7 +984,7 @@ def get_attention_backend( _flash_attn_max_version, ), ) - if use_flash_attention and not _flash_attn_is_installed: + if use_flash_attention and not _flash_attn_is_installed and not _flash_attn_3_is_installed: use_flash_attention = False available_backends[0] = False @@ -2055,10 +2053,11 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - if qkv_format == "thd": - flash_attn_fwd = _flash_attn_varlen_fwd_v3 - else: - flash_attn_fwd = _flash_attn_fwd_v3 + #if qkv_format == "thd": + # flash_attn_fwd = _flash_attn_varlen_fwd_v3 + #else: + # flash_attn_fwd = _flash_attn_fwd_v3 + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) else: if qkv_format == "thd": @@ -3822,10 +3821,11 
@@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - if qkv_format == "thd": - flash_attn_fwd = _flash_attn_varlen_fwd_v3 - else: - flash_attn_fwd = _flash_attn_fwd_v3 + #if qkv_format == "thd": + # flash_attn_fwd = _flash_attn_varlen_fwd_v3 + #else: + # flash_attn_fwd = _flash_attn_fwd_v3 + flash_attn_fwd = _flash_attn_fwd_v3 else: if qkv_format == "thd": flash_attn_fwd = _flash_attn_varlen_fwd @@ -4286,10 +4286,11 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - if qkv_format == "thd": - flash_attn_fwd = _flash_attn_varlen_fwd_v3 - else: - flash_attn_fwd = _flash_attn_fwd_v3 + #if qkv_format == "thd": + # flash_attn_fwd = _flash_attn_varlen_fwd_v3 + #else: + # flash_attn_fwd = _flash_attn_fwd_v3 + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = window_size else: if qkv_format == "thd": @@ -5895,7 +5896,7 @@ def forward( "FA 1", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout ) # get accurate batch_size, max_seqlen and cu_seqlens - batch_size = None + batch_size, context_len, total_tokens = None, None, None if inference_params is None: if qkv_format in ["sbhd", "bshd"]: batch_size = query_layer.shape[0] @@ -5999,7 +6000,37 @@ def forward( head_dim, batch_size * context_len, ) + #if _use_flash_attn_3 and qkv_format in ["thd_2bshd"]: + # total_tokens, num_heads, head_dim = query_layer.shape + # batch_size = key_layer.shape[0] + # cu_seqlens_q = cu_seqlens_q[: batch_size + 1] + # cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] + # # convert to bshd_2bshd for flash_attn_with_kvcache_v3 + # if isinstance(query_layer, Float8Tensor): + # query_layer._data = tex.convert_thd_to_bshd( + # query_layer._data, + # cu_seqlens_q, + # batch_size, + # inference_params.max_ctx_len, + # num_heads, + # head_dim, + # ) + # query_layer = Float8Tensor.make_like( + # query_layer, 
data=query_layer._data, shape=query_layer._data.shape + # ) + # else: + # query_layer = tex.convert_thd_to_bshd( + # query_layer, + # cu_seqlens_q, + # batch_size, + # inference_params.max_ctx_len, + # num_heads, + # head_dim, + # ) + print( + "FA 1", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout + ) if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -6056,15 +6087,17 @@ def forward( ): func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: - func = ( - flash_attn_varlen_func - if not _use_flash_attn_3 - else flash_attn_varlen_func_v3 - ) - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) - fa_optional_forward_args_thd.append(max_seqlen_q) - fa_optional_forward_args_thd.append(max_seqlen_kv) + if not _use_flash_attn_3: + func = flash_attn_varlen_func + elif inference_params is None: + func = flash_attn_varlen_func_v3 + else: + func = flash_attn_with_kvcache_v3 + if not _use_flash_attn_3 or inference_params is None: + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append(max_seqlen_q) + fa_optional_forward_args_thd.append(max_seqlen_kv) if inference_params is not None: # use page_table to support thd_2bshd format when is_paged=False fa_optional_forward_kwargs["block_table"] = ( @@ -6077,20 +6110,21 @@ def forward( if _use_flash_attn_3: fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size - fa_3_optional_forward_kwargs["deterministic"] = self.deterministic - if inference_params is not None: - # use page_table to support thd_2bshd format when is_paged=False - # 2.7.3+ -> page_table - # git clone --recursive -b v2.7.3 https://github.com/Dao-AILab/flash-attention.git - # MAX_JOBS=6 FLASH_ATTN_CUDA_ARCHS=90 pip install -v -e . 
- assert _flash_attn_3_plus, "Please install flash-attn from v3" - fa_3_optional_forward_kwargs["page_table"] = ( - inference_params.cache_manager.page_table[:batch_size] - if inference_params.is_paged - else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[ - :batch_size - ] - ) + if inference_params is None: + fa_3_optional_forward_kwargs["deterministic"] = self.deterministic + else: + fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q + fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q + cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens + if inference_params.is_paged: + fa_3_optional_forward_kwargs["page_table"] = inference_params.cache_manager.page_table[:batch_size] + #fa_3_optional_forward_kwargs["page_table"] = ( + # inference_params.cache_manager.page_table[:batch_size] + # else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[ + # :batch_size + # ] + #) if fp8: QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -6129,7 +6163,7 @@ def convert_to_torch_float8(tensor, dtype): for x in [query_layer, key_layer, value_layer] ) try: - output, _ = func( + output = func( query_layer, key_layer, value_layer, @@ -6138,6 +6172,8 @@ def convert_to_torch_float8(tensor, dtype): causal="causal" in attn_mask_type, **fa_3_optional_forward_kwargs, ) + if isinstance(output, List) or isinstance(output, Tuple): + output = output[0] except TypeError as e: if _flash_attn_3_0_0_beta: e.args = ( @@ -6190,6 +6226,33 @@ def convert_to_torch_float8(tensor, dtype): num_heads, head_dim, ) + #elif _use_flash_attn_3 and qkv_format in ["thd_2bshd"]: + # # if flash_attn_2, use flash_attn_varlen_func_v3 and thd_2bshd + # # if flash_attn_3, use flash_attn_with_kvcache_v3 and bshd + # #batch_size = cu_seqlens_q.shape[0] - 1 + # #total_tokens, num_heads, head_dim = query_layer.shape + # # convert back to 
thd_2bshd from bshd + # if isinstance(query_layer, Float8Tensor): + # output._data = tex.convert_bshd_to_thd( + # output._data, + # cu_seqlens_q, + # batch_size, + # inference_params.max_ctx_len, + # num_heads, + # head_dim, + # total_tokens, + # ) + # output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape) + # else: + # output = tex.convert_bshd_to_thd( + # output, + # cu_seqlens_q, + # batch_size, + # inference_params.max_ctx_len, + # num_heads, + # head_dim, + # total_tokens, + # ) if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 83dc652c62..29cefa2586 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -111,6 +111,7 @@ def forward( # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape + print('inp_shape', inp_shape, weight.shape) assert inp_shape[-1] == in_features, "GEMM not possible" tp_world_size = get_distributed_world_size(tp_group) From bd93082ae9ed84f07797bd31131b36cdf942b73f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 1 Mar 2025 17:04:05 -0800 Subject: [PATCH 145/239] WIP: FlashAttention + v3 + FP8 + WIP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 114 +++++++++++++------- transformer_engine/pytorch/attention.py | 57 +++++----- 2 files changed, 105 insertions(+), 66 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index cc5670156d..d2652b6ce3 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -13,6 +13,8 @@ from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables +from transformer_engine.common import 
recipe +from transformer_engine.pytorch import fp8_autocast, fp8_model_init from transformer_engine.pytorch.transformer import ( TransformerLayer, ) @@ -230,6 +232,16 @@ def get_model( if mode == "inference": attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding" + fp8_mha = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_mha, + fp8_mha=fp8_mha, + ) + if module == "TransformerLayer": hidden_size = config.head_dim_qk * config.num_heads model = [ @@ -254,20 +266,21 @@ def get_model( for layer_number in range(1, num_layers + 1) ] if module == "DotProductAttention": - model = [ - DotProductAttention( - kv_channels=config.head_dim_qk, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - layer_number=layer_number, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - ) - .cuda() - .eval() - for layer_number in range(1, num_layers + 1) - ] + with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): + model = [ + DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] return model @@ -365,6 +378,7 @@ def get_tols(module, backend, dtype): tols = { torch.half: 1e-3, torch.bfloat16: 1e-2, + torch.float8_e4m3fn: 3e-2, } return tols[dtype] @@ -381,6 +395,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 config = model_configs_infer[model] + #dtype = torch.float8_e4m3fn + is_fp8 = True + # figure out supported backends inference_params_qkv_format = "bshd" if is_paged: @@ -395,6 
+412,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda pad_between_seqs=False, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + print('available_backends', available_backends) if backend == "FlashAttention" and not flash_attn_supported: pytest.skip("FlashAttention backend is not supported") if backend == "FusedAttention" and not fused_attn_supported: @@ -412,6 +430,10 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 config.max_seqlen_kv = 256 + print('qkv format', qkv_format) + #if dtype == torch.float8_e4m3fn and qkv_format != "thd": + if is_fp8 and (qkv_format != "thd" or module != "DotProductAttention"):# or dtype != torch.float16): + pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") # create full model model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference") @@ -486,6 +508,15 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # create inference model model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="inference") + fp8_mha = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_mha, + fp8_mha=fp8_mha, + ) # graph the model if necessary if is_cuda_graph: t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") @@ -515,16 +546,18 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - model = [ - make_graphed_callables( - model[i], - sample_args, - num_warmup_iters=10, - fp8_enabled=False, - sample_kwargs=sample_kwargs, - ) - for i in range(num_layers) - ] + with fp8_autocast(enabled=fp8_mha, 
fp8_recipe=fp8_recipe): + model = [ + make_graphed_callables( + model[i], + sample_args, + num_warmup_iters=10, + fp8_enabled=fp8_mha, #False, + sample_kwargs=sample_kwargs, + fp8_recipe=fp8_recipe, + ) + for i in range(num_layers) + ] sim.reset() inference_params.reset() @@ -615,22 +648,23 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if inference_params.is_paged: inference_params.cache_manager.print_cache() incremental_output = incremental_inputs - for m in model: - print("xxxxdgdg ", type(incremental_output)) - incremental_output = m( - *incremental_output if isinstance(incremental_output, List) else incremental_output, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - inference_params=inference_params, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - ) - print('ddddddddddd ', len(incremental_output), type(incremental_output)) - incremental_output = [incremental_output] + with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): + for m in model: + print("xxxxdgdg ", type(incremental_output)) + incremental_output = m( + *incremental_output if isinstance(incremental_output, List) else incremental_output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + inference_params=inference_params, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + ) + print('ddddddddddd ', len(incremental_output), type(incremental_output)) + incremental_output = [incremental_output] incremental_output = incremental_output[0] # compare results - tol = get_tols(module, backend, dtype) + tol = get_tols(module, backend, dtype=torch.float8_e4m3fn) for i, seq in enumerate(sim.t_seq_ids): # token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 token_index = sim.step_lens[i] - 1 @@ -660,9 +694,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda rtol=tol, ) if qkv_format == "thd": - # print('i ', i, seq, cu_seqlens_q) - # 
print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - # print(incremental_output[cu_seqlens_q[i + 1] - 1, :4]) + print('i ', i, seq, cu_seqlens_q) + print(full_output[seq, sim.t_total_lens[i] - 1, :4]) + print(incremental_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], # incremental_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6032206ece..dfe8959ba1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -452,6 +452,8 @@ def get_attention_backend( # install an appropriate FA version. global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + # Filter: Environment variables use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) @@ -481,14 +483,14 @@ def get_attention_backend( torch.Tensor, Float8Tensor, ]: - if use_flash_attention and _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_flash_attention = False + #if use_flash_attention and _flash_attn_is_installed: + # logger.debug( + # "Disabling FlashAttention due to unsupported QKV data type. " + # "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + # "Found: qkv_dtype = %s.", + # qkv_dtype, + # ) + #use_flash_attention = False if use_fused_attention: logger.debug( "Disabling FusedAttention due to unsupported QKV data type. 
" @@ -528,9 +530,9 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention: + if use_flash_attention and q_format != "thd": use_flash_attention = False - logger.debug("Disabling FlashAttention for FP8 KV caching") + logger.debug("Disabling FlashAttention for FP8 KV caching (only THD is supported)") if use_fused_attention and inference_params.is_paged: use_fused_attention = False logger.debug( @@ -593,7 +595,6 @@ def get_attention_backend( use_fused_attention = False # Filter: QKV layout - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) if qkv_format == "thd": if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") @@ -762,15 +763,15 @@ def get_attention_backend( "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) use_flash_attention = False - if ( - use_flash_attention - and _use_flash_attn_3 - and fp8 - and fp8_meta["recipe"].fp8_dpa - and "padding" in attn_mask_type - ): - logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") - _use_flash_attn_3 = False + #if ( + # use_flash_attention + # and _use_flash_attn_3 + # and fp8 + # and fp8_meta["recipe"].fp8_dpa + # and "padding" in attn_mask_type + #): + # logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") + # _use_flash_attn_3 = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -6125,6 +6126,7 @@ def forward( # :batch_size # ] #) + print('fp88888888888', fp8) if fp8: QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -6149,14 +6151,17 @@ def convert_to_torch_float8(tensor, dtype): query_layer, key_layer, value_layer = ( QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] ) - fa_3_optional_forward_kwargs["descale_q"] = ( - query_layer._scale_inv.unsqueeze(0) + 
batch_size = cu_seqlens_q.shape[0] - 1 + num_heads_q = query_layer.shape[-2] + num_heads_k = key_layer.shape[-2] + fa_3_optional_forward_kwargs["q_descale"] = ( + query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_q) ) - fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( + fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze( 0 - ) - fa_3_optional_forward_kwargs["descale_v"] = ( - value_layer._scale_inv.unsqueeze(0) + ).repeat(batch_size, num_heads_k) + fa_3_optional_forward_kwargs["v_descale"] = ( + value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) ) query_layer, key_layer, value_layer = ( convert_to_torch_float8(x, torch_dtype) From 24287938767dde678317521efef9cbd86eb21920 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 1 Mar 2025 19:59:06 -0800 Subject: [PATCH 146/239] WIP: add backend support table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index dfe8959ba1..496bd46de7 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -516,11 +516,14 @@ def get_attention_backend( use_unfused_attention = False # Filter: KV cache - # backend | precision - # ------------------------------------------------------------------------- - # FlashAttention | FP16/BF16 (non-paged/paged) - # FusedAttention | FP16/BF16 (non-paged/paged), FP8 (non-paged) - # UnfusedDotProductAttention | FP32/FP16/BF16 (non-paged/paged) + # backend | precision | KV cache | architecture | qkv_format + # -------------------------------------------------------------------------------------------- + # FusedAttention | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd + # | FP8 | 
non-paged | sm89+ | bshd,sbhd,thd + # FlashAttention v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd + # FlashAttention v3 | FP16/BF16 | non-paged/paged | sm80 | bshd,sbhd,thd + # | FP8 | non-paged/paged | sm80 | thd + # UnfusedDotProductAttention | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd if inference_params is not None: if context_parallel: logger.debug( @@ -6088,6 +6091,18 @@ def forward( ): func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: + # version | API | use cases + # ------------------------------------------------------------------------ + # FA v2 | flash_attn_func | bshd/sbhd + not padding + # | flash_attn_varlen_func | bshd/sbhd + padding + # | | thd + padding + not pad_between_seqs + # | | KV cache (not-paged/paged), i.e. + # | | bshd/sbhd/thd + padding + # FA v3 | flash_attn_func | bshd/sbhd + not padding + # | flash_attn_varlen_func | bshd/sbhd + padding + # | | thd + padding + not pad_between_seqs + # | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e. 
+ # | | bshd/sbhd/thd + padding if not _use_flash_attn_3: func = flash_attn_varlen_func elif inference_params is None: From 91fa9029774bf7435a01e55252bfcf48155e547a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 1 Mar 2025 22:36:28 -0800 Subject: [PATCH 147/239] WIP: clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 14 +- tests/pytorch/fused_attn/test_paged_attn.py | 53 +-- .../common/fused_attn/fused_attn.cpp | 5 +- .../fused_attn_f16_arbitrary_seqlen.cu | 24 +- .../common/fused_attn/fused_attn_fp8.cu | 24 +- transformer_engine/common/fused_attn/utils.h | 20 +- transformer_engine/pytorch/attention.py | 425 +++++++----------- transformer_engine/pytorch/csrc/extensions.h | 9 +- .../pytorch/csrc/extensions/attention.cu | 19 +- transformer_engine/pytorch/inference.py | 100 +---- transformer_engine/pytorch/module/linear.py | 1 - 11 files changed, 238 insertions(+), 456 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 90e77ffb45..775bf1651e 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -204,13 +204,13 @@ def test(): return available_backends, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - #with logging_context(): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + with logging_context(): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, fused_attention_backend = test() + if 
fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, fused_attn_backends diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index d2652b6ce3..8e7a6d4b8a 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -14,6 +14,7 @@ from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe +import transformer_engine.pytorch.fp8 as fp8 from transformer_engine.pytorch import fp8_autocast, fp8_model_init from transformer_engine.pytorch.transformer import ( TransformerLayer, @@ -35,6 +36,7 @@ _get_attention_backends, ) from tests.pytorch.test_numerics import assert_allclose +fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() # Initialize RNG state seed = 1234 @@ -47,6 +49,8 @@ param_types = [torch.float16] if is_bf16_compatible(): param_types.append(torch.bfloat16) +if fp8_available: + param_types.append(torch.float8_e4m3fn) model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -93,8 +97,6 @@ def __init__( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") - # self.context_lens[0] = 2 - # self.context_lens[2] = 3 # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -113,7 +115,6 @@ def __init__( self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to( dtype=torch.int32, device="cpu" ) - # self.arrival_times[2] = 0 # self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() @@ -220,6 +221,8 @@ def get_model( qkv_format: str = "bshd", num_layers: int = 1, mode: str = "reference", + fp8_dpa: bool = False, + fp8_mha: bool = 
False, ): reset_rng_states() sigma = 0.023 @@ -232,13 +235,12 @@ def get_model( if mode == "inference": attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding" - fp8_mha = True fp8_recipe = recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.HYBRID, amax_history_len=1, amax_compute_algo="most_recent", - fp8_dpa=fp8_mha, + fp8_dpa=fp8_dpa, fp8_mha=fp8_mha, ) @@ -256,8 +258,7 @@ def get_model( output_layer_init_method=output_layer_init_method, layer_number=layer_number, kv_channels=config.head_dim_qk, - self_attn_mask_type=attn_mask_type, # "padding", #_causal", - # enc_dec_attn_mask_type="padding", #_causal", + self_attn_mask_type=attn_mask_type, params_dtype=dtype, attn_input_format=qkv_format, ) @@ -395,8 +396,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 config = model_configs_infer[model] - #dtype = torch.float8_e4m3fn - is_fp8 = True + is_fp8 = dtype == torch.float8_e4m3fn + if is_fp8: + dtype = torch.bfloat16 # figure out supported backends inference_params_qkv_format = "bshd" @@ -412,7 +414,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda pad_between_seqs=False, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - print('available_backends', available_backends) if backend == "FlashAttention" and not flash_attn_supported: pytest.skip("FlashAttention backend is not supported") if backend == "FusedAttention" and not fused_attn_supported: @@ -430,9 +431,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 config.max_seqlen_kv = 256 - print('qkv format', qkv_format) - #if dtype == torch.float8_e4m3fn and qkv_format != "thd": - if is_fp8 and (qkv_format != "thd" or module != "DotProductAttention"):# or dtype != torch.float16): + if 
is_fp8 and (qkv_format != "thd" or module != "DotProductAttention"): pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") # create full model @@ -452,18 +451,12 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda full_output = m( *full_output if isinstance(full_output, List) else full_output, ) - # rotary_freqs = torch.randn((config.max_seqlen_kv, 1, 1, config.num_heads), dtype=torch.float, device="cuda") if module == "TransformerLayer": full_output = full_inputs for m in model: - print("xxxxxxxxxxxxxxxxxxxxxxxx ", type(full_output)) full_output = m( full_output[0] if isinstance(full_output, List) else full_output, - # rotary_pos_emb=rotary_freqs, ) - print("full", full_output[0, :2, :8]) - print("full", full_output[1, :7, :8]) - print("full", full_output[2, :3, :8]) # simulate real-life inference logger.info("=== Generating one token at a time ===") @@ -500,22 +493,20 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda head_dim_q=config.head_dim_qk, max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, - # allow_query_conversion=backend != "FusedAttention", ) for layer_number in range(1, num_layers + 1): inference_params.allocate_memory(layer_number, qkv_format) # create inference model - model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="inference") + model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="inference", fp8_dpa=is_fp8, fp8_mha=is_fp8) - fp8_mha = True fp8_recipe = recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.HYBRID, amax_history_len=1, amax_compute_algo="most_recent", - fp8_dpa=fp8_mha, - fp8_mha=fp8_mha, + fp8_dpa=is_fp8, + fp8_mha=is_fp8, ) # graph the model if necessary if is_cuda_graph: @@ -546,13 +537,13 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - with 
fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): model = [ make_graphed_callables( model[i], sample_args, num_warmup_iters=10, - fp8_enabled=fp8_mha, #False, + fp8_enabled=is_fp8, sample_kwargs=sample_kwargs, fp8_recipe=fp8_recipe, ) @@ -648,9 +639,8 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if inference_params.is_paged: inference_params.cache_manager.print_cache() incremental_output = incremental_inputs - with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): for m in model: - print("xxxxdgdg ", type(incremental_output)) incremental_output = m( *incremental_output if isinstance(incremental_output, List) else incremental_output, cu_seqlens_q=cu_seqlens_q, @@ -659,20 +649,17 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda max_seqlen_q=max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, ) - print('ddddddddddd ', len(incremental_output), type(incremental_output)) incremental_output = [incremental_output] incremental_output = incremental_output[0] # compare results - tol = get_tols(module, backend, dtype=torch.float8_e4m3fn) + tol = get_tols(module, backend, dtype=dtype) for i, seq in enumerate(sim.t_seq_ids): - # token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1 token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": print(i, seq, sim.t_total_lens, sim.step_lens, token_index) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(incremental_output[i, token_index, :4]) - # print(incremental_output[i, sim.step_lens[i] - 1, :4]) torch.testing.assert_close( # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], # incremental_output[:sim.step_lens[i] - 1, i, :], @@ -707,8 +694,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda ) sim.t += 1 sim.t_gen_lens = 
sim.t_gen_lens + 1 - # if sim.t == 1: - # break sim.serving_times = sim.arrival_times + sim.request_delays sim.complete_times = sim.serving_times + sim.gen_lens diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 8f00e1f644..b7620931ef 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -273,12 +273,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - //max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) (cudnn_runtime_version >= 90600 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - //max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && @@ -330,7 +330,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (supported_ragged_offset_size)) { flag_arb = true; } - flag_arb = true; if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index e12122f822..e7a60f11ab 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -85,7 +85,7 @@ 
void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { - NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); + NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); } // keep original batch size because cu_seqlens are created with [b+1] shape @@ -108,10 +108,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_kv, d_qk, d_v, - num_pages_k, - num_pages_v, - page_size_k, - page_size_v, + //num_pages_k, + //num_pages_v, + //page_size_k, + //page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, @@ -516,7 +516,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { - NVTE_CHECK(is_padding, "Paged attention requires padding masks!"); + NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); } // keep original batch size because cu_seqlens are created with [b+1] shape @@ -542,12 +542,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( s_kv, d_qk, d_v, - 0, - 0, - 0, - 0, - 0, - 0, + //0, + //0, + //0, + //0, + 1, + 1, bias_b, bias_h, scaling_factor, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index eacd8b53b4..88883a3ed0 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1679,12 +1679,12 @@ void fused_attn_fp8_fwd_impl_v1( s_kv, d, d, - 0, - 0, - 0, - 0, - 0, - 0, + //0, + //0, + //0, + //0, + 1, + 1, bias_b, bias_h, scaling_factor, @@ -1983,12 +1983,12 @@ void fused_attn_fp8_bwd_impl_v1( s_kv, d, d, - 0, - 0, - 0, - 0, - 0, - 0, + //0, + //0, + //0, + //0, + 1, + 1, bias_b, bias_h, scaling_factor, diff --git 
a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 30702a875d..4766f80e34 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -93,10 +93,10 @@ struct FADescriptor_v1 { std::int64_t s_kv; std::int64_t d_qk; std::int64_t d_v; - std::int64_t num_pages_k; - std::int64_t num_pages_v; - std::int64_t page_size_k; - std::int64_t page_size_v; + //std::int64_t num_pages_k; + //std::int64_t num_pages_v; + //std::int64_t page_size_k; + //std::int64_t page_size_v; std::int64_t max_pages_per_seq_k; std::int64_t max_pages_per_seq_v; std::int64_t bias_b; @@ -114,13 +114,15 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, + //num_pages_k, num_pages_v, page_size_k, page_size_v, + max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, - rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, + //rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, + rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, + rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); diff --git 
a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 496bd46de7..9c236d1b58 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Attention""" +"""Attention.""" import collections from contextlib import nullcontext from importlib.metadata import version as get_pkg_version @@ -122,14 +122,13 @@ def _get_supported_versions(version_min, version_max): _flash_attn_max_version = PkgVersion("2.7.4.post1") _flash_attn_2_plus = False _flash_attn_2_1_plus = False -_flash_attn_2_2_plus = False _flash_attn_2_3_plus = False _flash_attn_2_4_plus = False _flash_attn_2_4_1_plus = False _flash_attn_2_5_plus = False _flash_attn_2_5_7_plus = False -_flash_attn_2_6_plus = False -_flash_attn_2_7_plus = False +_flash_attn_2_6_0_plus = False +_flash_attn_2_7_0_plus = False flash_attn_cuda_bwd = None flash_attn_func = None @@ -168,16 +167,13 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") - _flash_attn_2_2_plus = _flash_attn_version >= PkgVersion("2.2") - if _flash_attn_2_2_plus: - from flash_attn.flash_attn_interface import flash_attn_with_kvcache _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_plus = _flash_attn_version >= PkgVersion("2.5.0") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") - _flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6.0") - _flash_attn_2_7_plus = _flash_attn_version >= PkgVersion("2.7.0") + _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") + _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0") elif ( torch.cuda.is_available() and get_device_compute_capability() >= (8, 
0) and _NVTE_FLASH_ATTN ): @@ -194,48 +190,45 @@ def _get_supported_versions(version_min, version_max): _flash_attn_version, ) -# Detect flash_attn_3 in the environment +# Detect flash-attn v2 in the environment (Hopper only) _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False _use_flash_attn_3 = False _flash_attn_3_installation_steps = """\ - git clone https://github.com/Dao-AILab/flash-attention.git - cd flash-attention/ - git checkout 39e7197 - cd hopper/ - python setup.py install - python_path=`python -c "import site; print(site.getsitepackages()[0])"` - mkdir -p $python_path/flash_attn_3 - wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/refs/heads/main/hopper/flash_attn_interface.py""" -try: - _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) -except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( - "flash-attn v3 is not installed. To use, please install it by \n%s", - _flash_attn_3_installation_steps, +(1) git clone https://github.com/Dao-AILab/flash-attention.git +(2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install +(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` +(4) mkdir -p $python_path/flash_attn_3 +(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/refs/heads/main/hopper/flash_attn_interface.py""" +if torch.cuda.is_available() and get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: + try: + _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) + except PackageNotFoundError: + fa_logger.debug( + "flash-attn v3 is not installed. 
To use, please install it by \n%s", + _flash_attn_3_installation_steps, + ) + else: + from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_3.flash_attn_interface import ( + flash_attn_varlen_func as flash_attn_varlen_func_v3, ) -else: - from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 - from flash_attn_3.flash_attn_interface import ( - flash_attn_varlen_func as flash_attn_varlen_func_v3, - ) - from flash_attn_3.flash_attn_interface import ( - flash_attn_with_kvcache as flash_attn_with_kvcache_v3, - ) - from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 - from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 - #from flash_attn_3.flash_attn_interface import ( - # _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, - #) - #from flash_attn_3.flash_attn_interface import ( - # _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, - #) - - _flash_attn_3_is_installed = True - _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") - _use_flash_attn_3 = True + from flash_attn_3.flash_attn_interface import ( + flash_attn_with_kvcache as flash_attn_with_kvcache_v3, + ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 + from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 + #from flash_attn_3.flash_attn_interface import ( + # _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, + #) + #from flash_attn_3.flash_attn_interface import ( + # _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, + #) + + _flash_attn_3_is_installed = True + _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") + _use_flash_attn_3 = True _attention_backends = { "attention_params": None, @@ -302,7 +295,7 @@ class AttentionParams: Whether `DotProductAttention` is in an `fp8_autocast` 
region. fp8_meta: Optional[Dict[str Any]], default = `None` The FP8 metadata tensor of `DotProductAttention`. - inference_params: Optional[object], default = `None` + inference_params: Optional[InferenceParams], default = `None` Inference-related parameters. See InferenceParams for details. """ @@ -329,7 +322,7 @@ class AttentionParams: is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None - inference_params: Optional[object] = None + inference_params: Optional[InferenceParams] = None def __eq__(self, other): """ @@ -452,6 +445,7 @@ def get_attention_backend( # install an appropriate FA version. global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 + # get q/kv format qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables @@ -483,14 +477,14 @@ def get_attention_backend( torch.Tensor, Float8Tensor, ]: - #if use_flash_attention and _flash_attn_is_installed: - # logger.debug( - # "Disabling FlashAttention due to unsupported QKV data type. " - # "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - # "Found: qkv_dtype = %s.", - # qkv_dtype, - # ) - #use_flash_attention = False + if use_flash_attention and _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_flash_attention = False if use_fused_attention: logger.debug( "Disabling FusedAttention due to unsupported QKV data type. 
" @@ -533,27 +527,27 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and q_format != "thd": + if use_flash_attention: use_flash_attention = False - logger.debug("Disabling FlashAttention for FP8 KV caching (only THD is supported)") + logger.debug("Disabling FlashAttention for FP8 KV caching") + if _use_flash_attn_3 and q_format != "thd": + _use_flash_attn_3 = False + logger.debug("Disabling FlashAttention 3 for FP8 KV caching in non-THD") if use_fused_attention and inference_params.is_paged: use_fused_attention = False logger.debug( - "Disabling FusedAttention as it does not support paged attention in FP8" + "Disabling FusedAttention for paged attention in FP8" ) if use_unfused_attention: use_unfused_attention = False - logger.debug("Disabling UnfusedAttention as it does not support FP8 attention") + logger.debug("Disabling UnfusedAttention for FP8 attention") else: - if use_flash_attention and not _flash_attn_2_2_plus and not _use_flash_attn_3: - use_flash_attention = False - logger.debug( - "Disabling FlashAttention as KV caching requires flash-attn 2.2+, or 3.0" - " (Hopper only)" - ) if use_fused_attention and pad_between_seqs: use_fused_attention = False logger.debug("Disabling FusedAttention for pad_between_seqs = True and KV caching") + if use_flash_attention and pad_between_seqs: + use_flash_attention = False + logger.debug("Disabling FlashAttention for pad_between_seqs = True and KV caching") if inference_params.is_paged: if use_fused_attention and cudnn_version < (9, 5, 0): logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+") @@ -766,15 +760,6 @@ def get_attention_backend( "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) use_flash_attention = False - #if ( - # use_flash_attention - # and _use_flash_attn_3 - # and fp8 - # and fp8_meta["recipe"].fp8_dpa - # and "padding" in attn_mask_type - #): 
- # logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") - # _use_flash_attn_3 = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -1061,7 +1046,13 @@ def get_attention_backend( @torch.no_grad() -def get_attn_mask(batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): +def get_padding_mask( + batch_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_q: int, + max_seqlen_kv: int, + ): """Convert cu_seqlens to attention_mask""" seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] @@ -1209,6 +1200,7 @@ def get_full_mask( m = attention_mask.logical_not() actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) + # apply SWA mask mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 @@ -2051,7 +2043,7 @@ def forward( if use_fused_attention: softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) else: - softmax_lse_in_packed_format = _flash_attn_2_6_plus or _use_flash_attn_3 + softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3 flash_attn_fwd = None if not use_fused_attention: @@ -2070,16 +2062,16 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if (_flash_attn_2_3_plus and not _flash_attn_2_7_plus) or _use_flash_attn_3: + if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = 0 if causal else -1 if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_plus: + if 
_flash_attn_2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 # Flash Attn inputs @@ -2247,7 +2239,7 @@ def forward( causal=True, **fa_forward_kwargs, ) - if not _flash_attn_2_7_plus: + if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2361,10 +2353,10 @@ def forward( max_seqlen_kv // 2, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( @@ -2383,7 +2375,7 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_plus: + if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2506,10 +2498,10 @@ def forward( max_seqlen_kv, ] if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_plus + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( @@ -2528,7 +2520,7 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_plus: + if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -2642,7 +2634,7 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_plus: + if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -3026,7 +3018,7 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = 
ctx.deterministic - if _flash_attn_2_6_plus: + if _flash_attn_2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 for i in range(cp_size): @@ -3160,9 +3152,9 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, 0) - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 if not _use_flash_attn_3: @@ -3275,9 +3267,9 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - if _flash_attn_2_7_plus: + if _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: @@ -3392,9 +3384,9 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: @@ -3486,9 +3478,9 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not 
_use_flash_attn_3: @@ -3841,7 +3833,7 @@ def forward( fa_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_plus: + if _flash_attn_2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 assert qkv_format != "thd", f"{qkv_format} format is not supported!" @@ -3951,9 +3943,9 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size_per_step[i] - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( @@ -3964,7 +3956,7 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not _flash_attn_2_7_plus: + if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] if not _use_flash_attn_3: @@ -4088,7 +4080,7 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_plus: + if _flash_attn_2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 for i in range(len(local_seq_chunk_ids) + 1): @@ -4147,9 +4139,9 @@ def backward(ctx, dout): ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] - if _flash_attn_2_3_plus and not _flash_attn_2_7_plus: + if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] - if _flash_attn_2_7_plus: + if _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( @@ -4303,16 +4295,16 @@ def forward( flash_attn_fwd = _flash_attn_fwd 
fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size[0] fa_forward_kwargs["window_size_right"] = window_size[1] if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_plus: + if _flash_attn_2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 assert ( @@ -4424,7 +4416,7 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not _flash_attn_2_7_plus: + if not _flash_attn_2_7_0_plus: out, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not _use_flash_attn_3 else None else: @@ -4613,16 +4605,16 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_plus): + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = ctx.window_size - elif _flash_attn_2_7_plus: + elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = ctx.window_size[0] fa_backward_kwargs["window_size_right"] = ctx.window_size[1] if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_plus: + if _flash_attn_2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 if ctx.use_fused_attention: @@ -5281,7 +5273,7 @@ def forward( if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged( self.layer_number - ) # , inference_params.input_qkv_format 
+ ) if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now @@ -5299,8 +5291,6 @@ def forward( cu_seqlens_q, batch_size, inference_params.max_ctx_len, - query_layer.shape[-2], - query_layer.shape[-1], ) query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] @@ -5311,8 +5301,8 @@ def forward( key_layer.shape[0], ) - if "padding" in attn_mask_type: # and qkv_format in ["bshd", "sbhd"]: - attention_mask = get_attn_mask( + if "padding" in attn_mask_type and attention_mask is None: + attention_mask = get_padding_mask( batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv ) attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( @@ -5458,7 +5448,7 @@ def forward( # [b, sq, np, hn] --> [b, sq, hp] context_layer = context_layer.view(batch_size, seqlen, -1) - if qkv_format == "thd_2bshd": + if q_format == "thd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -5466,10 +5456,6 @@ def forward( context_layer = tex.convert_bshd_to_thd( context_layer, cu_seqlens_q, - batch_size, - inference_params.max_ctx_len, - context_layer.shape[-2], - context_layer.shape[-1], total_tokens, ) @@ -5518,7 +5504,7 @@ def get_qkv_format( qkv_layout: str = "bshd_bshd_bshd", inference_params: InferenceParams = None, ) -> str: - """Get qkv layout. + """Get qkv format. Parameters ---------- @@ -5532,24 +5518,17 @@ def get_qkv_format( qkv_format: str, default = `sbhd` Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. q_format: str - Format of the query tensor, {`bshd`, `sbhd`, `thd`}. + Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}. kv_format: str - Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. + Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}. 
""" - if inference_params is not None: # "_2" in qkv_layout: - # qkv_format = qkv_layout.replace("paged_kv_", "") - # q_format, kv_format = qkv_format.split("_2") - q_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) - kv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[1] if i.isalpha()] - ) + splited = qkv_layout.replace("paged_kv_", "").split("_") + if inference_params is not None: + q_format = "".join([i for i in splited[0] if i.isalpha()]) + kv_format = "".join([i for i in splited[1] if i.isalpha()]) qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format else: - qkv_format = "".join( - [i for i in qkv_layout.replace("paged_kv_", "").split("_")[0] if i.isalpha()] - ) + qkv_format = "".join([i for i in splited[0] if i.isalpha()]) q_format = qkv_format kv_format = qkv_format return qkv_format, q_format, kv_format @@ -5830,6 +5809,7 @@ def forward( inference_params: Optional[InferenceParams] = None, ) -> torch.Tensor: """flash-attn fprop""" + assert all( x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] @@ -5852,8 +5832,7 @@ def forward( # get q_format and kv_format for training and inference qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) - # convert q, k, v to bshd if they are in sbhd - # qkv_format is unchanged + # convert q, k, v to bshd if they are in sbhd; qkv_format doesn't change if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): if qkv_format == "sbhd": # For now just 128, will make it more general in the future @@ -5896,9 +5875,6 @@ def forward( x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - print( - "FA 1", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout - ) # get accurate batch_size, max_seqlen and cu_seqlens batch_size, context_len, 
total_tokens = None, None, None if inference_params is None: @@ -5964,7 +5940,7 @@ def forward( max_seqlen_kv, key_layer.device, ) - if qkv_format == "thd": + elif qkv_format == "thd": assert ( cu_seqlens_q is not None and cu_seqlens_kv is not None ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" @@ -5976,19 +5952,15 @@ def forward( max_seqlen_kv = seqlens_kv.max().item() else: if qkv_format in ["sbhd_2bshd", "bshd"]: - # q is in bshd in both cases (conversion above or original input) + # q is in bshd in both cases from conversion above or the original input batch_size, context_len, num_heads, head_dim = query_layer.shape cu_seqlens_q = cu_seqlens_q[: batch_size + 1] cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] - # convert to thd_2bshd + # convert from bshd to thd_2bshd if isinstance(query_layer, Float8Tensor): query_layer._data = tex.convert_bshd_to_thd( query_layer._data, cu_seqlens_q, - batch_size, - context_len, - num_heads, - head_dim, batch_size * context_len, ) query_layer = Float8Tensor.make_like( @@ -5998,43 +5970,9 @@ def forward( query_layer = tex.convert_bshd_to_thd( query_layer, cu_seqlens_q, - batch_size, - context_len, - num_heads, - head_dim, batch_size * context_len, ) - #if _use_flash_attn_3 and qkv_format in ["thd_2bshd"]: - # total_tokens, num_heads, head_dim = query_layer.shape - # batch_size = key_layer.shape[0] - # cu_seqlens_q = cu_seqlens_q[: batch_size + 1] - # cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] - # # convert to bshd_2bshd for flash_attn_with_kvcache_v3 - # if isinstance(query_layer, Float8Tensor): - # query_layer._data = tex.convert_thd_to_bshd( - # query_layer._data, - # cu_seqlens_q, - # batch_size, - # inference_params.max_ctx_len, - # num_heads, - # head_dim, - # ) - # query_layer = Float8Tensor.make_like( - # query_layer, data=query_layer._data, shape=query_layer._data.shape - # ) - # else: - # query_layer = tex.convert_thd_to_bshd( - # query_layer, - # cu_seqlens_q, - # batch_size, - # 
inference_params.max_ctx_len, - # num_heads, - # head_dim, - # ) - - print( - "FA 1", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format, qkv_layout - ) + if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -6076,33 +6014,25 @@ def forward( tensor.activation_offloading = True with self.attention_dropout_ctx(): - fa_optional_forward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = window_size - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes - if _flash_attn_2_4_1_plus: - fa_optional_forward_kwargs["deterministic"] = self.deterministic + # | API | use cases + # ---------------------------------------------------------------------- + # FA v2 | flash_attn_func | bshd/sbhd + not padding + # | flash_attn_varlen_func | bshd/sbhd + padding + # | | thd + padding + not pad_between_seqs + # | | KV cache (not-paged/paged), i.e. + # | | bshd/sbhd/thd + padding + # FA v3 | flash_attn_func | bshd/sbhd + not padding + # | flash_attn_varlen_func | bshd/sbhd + padding + # | | thd + padding + not pad_between_seqs + # | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e. + # | | bshd/sbhd/thd + padding fa_optional_forward_args_thd = [] if ( qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type - and inference_params is None ): func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: - # version | API | use cases - # ------------------------------------------------------------------------ - # FA v2 | flash_attn_func | bshd/sbhd + not padding - # | flash_attn_varlen_func | bshd/sbhd + padding - # | | thd + padding + not pad_between_seqs - # | | KV cache (not-paged/paged), i.e. 
- # | | bshd/sbhd/thd + padding - # FA v3 | flash_attn_func | bshd/sbhd + not padding - # | flash_attn_varlen_func | bshd/sbhd + padding - # | | thd + padding + not pad_between_seqs - # | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e. - # | | bshd/sbhd/thd + padding if not _use_flash_attn_3: func = flash_attn_varlen_func elif inference_params is None: @@ -6114,8 +6044,16 @@ def forward( fa_optional_forward_args_thd.append(cu_seqlens_kv) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) + if not _use_flash_attn_3: + fa_optional_forward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes + if _flash_attn_2_4_1_plus: + fa_optional_forward_kwargs["deterministic"] = self.deterministic if inference_params is not None: - # use page_table to support thd_2bshd format when is_paged=False + # use block_table to support thd_2bshd format for non-paged fa_optional_forward_kwargs["block_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged @@ -6123,7 +6061,17 @@ def forward( :batch_size ] ) - if _use_flash_attn_3: + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, + ) + else: fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size if inference_params is None: @@ -6135,13 +6083,6 @@ def forward( fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens if inference_params.is_paged: fa_3_optional_forward_kwargs["page_table"] = inference_params.cache_manager.page_table[:batch_size] - #fa_3_optional_forward_kwargs["page_table"] = ( - # inference_params.cache_manager.page_table[:batch_size] - # else 
inference_params.cache_manager.batch_indices_post.unsqueeze(1)[ - # :batch_size - # ] - #) - print('fp88888888888', fp8) if fp8: QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -6209,32 +6150,19 @@ def convert_to_torch_float8(tensor, dtype): if fp8 and fp8_meta["recipe"].fp8_mha: O_quantizer = quantizers["scaling_fwd"][META_O] output = O_quantizer(output) - else: - output = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) if inference_params is None: if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) elif qkv_format in ["bshd", "sbhd_2bshd"]: - # convert back to bshd_2bshd from thd_2bshd - # batch_size, context_len, num_heads, head_dim = output.shape + # all KV caching cases use thd_2bshd for calculation + # convert results from thd_2bshd back to bshd if isinstance(query_layer, Float8Tensor): output._data = tex.convert_thd_to_bshd( output._data, cu_seqlens_q, batch_size, context_len, - num_heads, - head_dim, ) output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape) else: @@ -6243,36 +6171,7 @@ def convert_to_torch_float8(tensor, dtype): cu_seqlens_q, batch_size, context_len, - num_heads, - head_dim, ) - #elif _use_flash_attn_3 and qkv_format in ["thd_2bshd"]: - # # if flash_attn_2, use flash_attn_varlen_func_v3 and thd_2bshd - # # if flash_attn_3, use flash_attn_with_kvcache_v3 and bshd - # #batch_size = cu_seqlens_q.shape[0] - 1 - # #total_tokens, num_heads, head_dim = query_layer.shape - # # convert back to thd_2bshd from bshd - # if isinstance(query_layer, Float8Tensor): - # output._data = tex.convert_bshd_to_thd( - # output._data, - # cu_seqlens_q, - # batch_size, - # 
inference_params.max_ctx_len, - # num_heads, - # head_dim, - # total_tokens, - # ) - # output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape) - # else: - # output = tex.convert_bshd_to_thd( - # output, - # cu_seqlens_q, - # batch_size, - # inference_params.max_ctx_len, - # num_heads, - # head_dim, - # total_tokens, - # ) if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) @@ -7756,20 +7655,16 @@ def forward( for x in [query_layer, key_layer, value_layer] ] - # update KV cache and retrieve full KV tokens ( - # query_layer, key_layer, value_layer, page_table, cu_seqlens_q, cu_seqlens_kv, - # max_seqlen_q, max_seqlen_kv, qkv_format, ) = inference_params.step( self.layer_number, - # query_layer, key_layer, value_layer, qkv_format, @@ -7777,9 +7672,6 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None - print( - "FA DPA", [x.shape for x in [query_layer, key_layer, value_layer]], qkv_format - ) # , qkv_layout) # get accurate qkv_layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( @@ -8129,13 +8021,6 @@ def forward( inference_params=inference_params, ) - # if inference_params is not None: - ## inference_params.is_output_right_aligned = use_flash_attention - # output = inference_params.post_step(self.layer_number, output) - # print(output[0,-2:,:8]) - # print(output[1,:,:8]) - # print(output[2,-3:,:8]) - return output @@ -8850,8 +8735,6 @@ def forward( f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE." 
) - # pylint: disable=fixme - # TODO: consider cases where sequences have different seqlens # sequence_start = inference_params.get_seqlens_pre_step() sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f73dcf140f..ef0188fb36 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -70,14 +70,11 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, - int h, int d); -at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, - int h, int d, int t); +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, - torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int h_kv, int d_k, - int d_v, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, + torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 66f27f87c6..00221e0bab 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1043,8 +1043,9 @@ void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor 
new_tensor, at:: b, max_seq_len, h, d); } -at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, - int h, int d) { +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) { + int h = tensor.size(1); + int d = tensor.size(2); std::vector shape = {b, max_seq_len, h, d}; at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); if (new_tensor.scalar_type() == at::ScalarType::Half) { @@ -1082,8 +1083,11 @@ void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at:: b, max_seq_len, h, d); } -at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len, - int h, int d, int t) { +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { + int b = tensor.size(0); + int max_seq_len = tensor.size(1); + int h = tensor.size(2); + int d = tensor.size(3); std::vector shape = {t, h, d}; at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); if (tensor.scalar_type() == at::ScalarType::Half) { @@ -1149,8 +1153,11 @@ void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens, - NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b, - int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged) { + NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq, bool is_non_paged) { + int h_kv = new_k.size(-2); + int d_k = new_k.size(-1); + int d_v = new_v.size(-1); NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && new_k.scalar_type() == new_v.scalar_type() && new_k.scalar_type() == k_cache.scalar_type(), diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 67c00dc781..8bffb5c2cf 100644 
--- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Inference.""" +"""Inference""" import os import logging from collections import OrderedDict, defaultdict @@ -90,9 +90,6 @@ class DotProductAttention: q, k_cache, v_cache, qkv_format = inference_params.step( new_q, new_k, new_v, qkv_format) output = attention(q, k_cache, v_cache, new_qkv_format) - if inference_params is not None: - output = inference_params.post_step(output) - return output InferenceParams supports cache_qkv_format = "bshd" only, and the step() function may change qkv_format depending on the attention backend. @@ -152,7 +149,6 @@ def __init__( max_ctx_len: int = None, qkv_format: str = "bshd", cache_manager: KVCacheManager = None, - # allow_query_conversion: bool = True, ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv @@ -164,9 +160,6 @@ def __init__( _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - # self.allow_query_conversion = allow_query_conversion and ( - # _NVTE_FLASH_ATTN or _NVTE_UNFUSED_ATTN or not _NVTE_FUSED_ATTN - # ) if not self.is_paged: cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager @@ -206,18 +199,11 @@ def __init__( if qkv_format == "thd": assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" self.max_ctx_len = max_ctx_len - # if self.allow_query_conversion: - # # query is converted to 'bshd' for certain backends - # assert num_heads_q is not None, "num_heads_q is required when qkv_format=thd!" - # assert head_dim_q is not None, "head_dim_q is required when qkv_format=thd!" 
- # self.num_heads_q = num_heads_q - # self.head_dim_q = head_dim_q - # self.max_seqlen_q = max_ctx_len # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: + if self.input_qkv_format == self.cache_qkv_format: self.output_qkv_format = self.cache_qkv_format else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format @@ -234,9 +220,6 @@ def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() self.cache_manager.reset() - # if self.input_qkv_format == "thd" and self.allow_query_conversion: - # for _, q_buffer in self.q_buffer.items(): - # q_buffer.fill_(0) def __repr__(self) -> str: if self.is_paged: @@ -282,16 +265,6 @@ def allocate_memory(self, layer_number: int, qkv_format: str): device=torch.cuda.current_device(), ) - # if qkv_format == "thd" and self.allow_query_conversion: - # self.q_buffer[layer_number] = torch.zeros( - # self.max_batch_size, - # self.max_ctx_len, - # self.num_heads_q, - # self.head_dim_q, - # dtype=self.dtype, - # device=torch.cuda.current_device(), - # ) - def pre_step( self, step_dict: OrderedDict, @@ -323,7 +296,7 @@ def get_seqlens_pre_step(self): """Get cached sequence lengths for current iteration before adding step_dict.values""" return self.sequences_pre - def convert_paged_to_nonpaged(self, layer_number: int): # , qkv_format: str): + def convert_paged_to_nonpaged(self, layer_number: int): """ Convert k_cache and v_cache from paged to non-paged format. This is used by the UnfusedDotProductAttention backend. 
Both k_cache and v_cache are assumed to be @@ -364,7 +337,6 @@ def convert_paged_to_nonpaged(self, layer_number: int): # , qkv_format: str): def step( self, layer_number: int, - # new_q: torch.Tensor, new_k: torch.Tensor, new_v: torch.Tensor, qkv_format: str, @@ -402,36 +374,12 @@ def step( qkv_format: str Updated qkv_format, e.g. the input 'thd' format may become 'thd_2bshd' after step() """ - print("self.sequences", self.sequences) self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format: # or self.allow_query_conversion: + if self.input_qkv_format == self.cache_qkv_format: self.output_qkv_format = self.cache_qkv_format else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - # q_buffer = new_q - # if qkv_format == "bshd": - # self.max_seqlen_q = new_q.shape[1] - # q_buffer = new_q.contiguous() - # if qkv_format == "sbhd": - # self.max_seqlen_q = new_q.shape[0] - # if self.allow_query_conversion: - # q_buffer = new_q.transpose(0, 1).contiguous() - # if qkv_format == "thd": - # self.max_seqlen_q = self.max_ctx_len - # if self.allow_query_conversion: - # q_buffer = self.q_buffer[layer_number] - # tex.convert_thd_to_bshd( - # new_q, - # self.q_buffer[layer_number], - # self.cu_seqlens_q, - # self.max_batch_size, - # self.max_ctx_len, - # self.num_heads_q, - # self.head_dim_q, - # ) - # self.q_orig[layer_number] = new_q - k_cache, v_cache, page_table = self.cache_manager.step( layer_number, new_k, @@ -442,46 +390,15 @@ def step( ) return ( - # q_buffer, k_cache, v_cache, page_table, self.cu_seqlens_q, self.cu_seqlens_kv, - # self.max_seqlen_q, self.max_seqlen_kv, self.output_qkv_format, ) - def post_step( - self, - layer_number: int, - output: torch.Tensor, - ): - """ - Process the attention output in order to return it to the original qkv_format. 
- """ - print("post step ", self.input_qkv_format) - # if self.input_qkv_format == "bshd": - # output = output[: self.batch_size, : self.max_seqlen_q].contiguous() - # if self.input_qkv_format == "sbhd": # and self.allow_query_conversion: - # output = output[: self.batch_size, : self.max_seqlen_q].transpose(0, 1).contiguous() - # if self.input_qkv_format == "thd": # and self.allow_query_conversion: - # output_buffer = self.q_orig[layer_number] - # tex.convert_bshd_to_thd( - # output, - # output_buffer, - # self.cu_seqlens_q, - # self.batch_size, - # self.max_ctx_len, - # self.num_heads_q, - # self.head_dim_q, - # self.total_tokens, - # ) - # output = output_buffer.view(output_buffer.shape[0], -1) - - return output - class NonPagedKVCacheManager(KVCacheManager): """Non-paged KV cache manager""" @@ -537,6 +454,7 @@ def allocate_memory(self, layer_number): dtype=torch.int32, device=torch.cuda.current_device(), ) + # always in [0, ..., b-1] fashion due to reindexing self.batch_indices_post = torch.range( 0, self.max_batch_size - 1, @@ -567,7 +485,6 @@ def pre_step( ) ).to(dtype=torch.int32, device="cpu") ) - print("self.batch_indices", self.batch_indices) # Advance unfinished sequences for i in unfinished_seqs: @@ -631,7 +548,6 @@ def step( batch_size = new_k.shape[1] ctx_len = new_k.shape[0] - # print('non-paged self.batch_indices', self.batch_indices) tex.copy_to_kv_cache( new_k, new_v, @@ -641,9 +557,6 @@ def step( cu_new_seqlens, cu_cached_seqlens, QKVFormat[qkv_format], - self.num_heads, - self.head_dim_k, - self.head_dim_v, batch_size, ctx_len, self.max_seqlen, @@ -906,9 +819,6 @@ def step( cu_new_seqlens, cu_cached_seqlens, QKVFormat[qkv_format], - self.num_heads, - self.head_dim_k, - self.head_dim_v, batch_size, ctx_len, self.max_seqlen, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 29cefa2586..83dc652c62 100644 --- a/transformer_engine/pytorch/module/linear.py +++ 
b/transformer_engine/pytorch/module/linear.py @@ -111,7 +111,6 @@ def forward( # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape - print('inp_shape', inp_shape, weight.shape) assert inp_shape[-1] == in_features, "GEMM not possible" tp_world_size = get_distributed_world_size(tp_group) From eeb0dc72ea7e62dfa1c57b0212d2a70c57de5674 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 2 Mar 2025 16:20:02 -0800 Subject: [PATCH 148/239] WIP: separate use_flash_attention_2 and _3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 15 +- transformer_engine/pytorch/attention.py | 528 ++++++++++++-------- 2 files changed, 336 insertions(+), 207 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 775bf1651e..6b333429cb 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -139,6 +139,7 @@ def _get_attention_backends( pad_between_seqs: bool = False, context_parallel: bool = False, deterministic: bool = False, + is_training: bool = True, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, inference_params: Optional[InferenceParams] = None, @@ -171,6 +172,7 @@ def _get_attention_backends( fused_attn_backends = [] available_backends = None + flash_attention_backend = None fused_attention_backend = None def test(): @@ -194,24 +196,25 @@ def test(): attention_dropout=config.dropout_p, context_parallel=context_parallel, deterministic=deterministic, + is_training=is_training, fp8=fp8, fp8_meta=fp8_meta, inference_params=inference_params, ) - _, _, fused_attention_backend, _, available_backends = get_attention_backend( + _, _, flash_attention_backend, fused_attention_backend, _, available_backends = get_attention_backend( attention_params ) - return available_backends, 
fused_attention_backend + return available_backends, flash_attention_backend, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} with logging_context(): for i in range(3): os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() + available_backends, flash_attention_backend, fused_attention_backend = test() if fused_attention_backend == FusedAttnBackend[backends[i]]: fused_attn_backends.append(fused_attention_backend) - return available_backends, fused_attn_backends + return available_backends, flash_attention_backend, fused_attn_backends model_configs_base = { @@ -263,7 +266,7 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -1127,7 +1130,7 @@ def test_transformer_layer( workspace_opt = True # Test backend availability - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout="sbh3d" if fused_qkv_params else "sb3hd", diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9c236d1b58..813438ffc0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -190,17 +190,17 @@ def _get_supported_versions(version_min, version_max): _flash_attn_version, ) -# Detect flash-attn v2 in the environment (Hopper only) +# Detect flash-attn v3 in the environment (Hopper only) _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False 
-_use_flash_attn_3 = False +#_use_flash_attn_3 = False _flash_attn_3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 -(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/refs/heads/main/hopper/flash_attn_interface.py""" +(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" if torch.cuda.is_available() and get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: try: _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) @@ -228,11 +228,12 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_is_installed = True _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") - _use_flash_attn_3 = True + #_use_flash_attn_3 = True _attention_backends = { "attention_params": None, "use_flash_attention": None, + "flash_attention_backend": None, "use_fused_attention": None, "fused_attention_backend": None, "use_unfused_attention": None, @@ -443,17 +444,22 @@ def get_attention_backend( # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. 
- global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 + global _flash_attn_version_required, _flash_attn_max_version#, _use_flash_attn_3 # get q/kv format qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_flash_attention_2 = use_flash_attention + use_flash_attention_3 = use_flash_attention + flash_attention_backend = None use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if not use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0") + if not use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0") if not use_fused_attention: logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") if not use_unfused_attention: @@ -461,110 +467,124 @@ def get_attention_backend( # Filter: Compute capability if device_compute_capability < (8, 0): - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as it requires compute capability sm80+") - use_flash_attention = False + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 for compute capability < sm80") + use_flash_attention_2 = False if use_fused_attention: - logger.debug("Disabling FusedAttention as it requires compute capability sm80+") + logger.debug("Disabling FusedAttention for compute capability < sm80") use_fused_attention = False - if device_compute_capability < (9, 0): - if use_flash_attention and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") - _use_flash_attn_3 = False + if 
device_compute_capability != (9, 0): + if use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 for compute capability != sm90") + use_flash_attention_3 = False # Filter: Data type - if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ + if qkv_dtype not in [torch.bfloat16, torch.float16]: + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention 2 for unsupported qkv_dtype = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. ", + qkv_dtype, + ) + use_flash_attention_2 = False + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ torch.Tensor, Float8Tensor, ]: - if use_flash_attention and _flash_attn_is_installed: + if use_flash_attention_3 and _flash_attn_3_is_installed: logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " + "qkv_type = {torch.Tensor, Float8Tensor}. ", qkv_dtype, + qkv_type, ) - use_flash_attention = False + use_flash_attention_3 = False if use_fused_attention: logger.debug( - "Disabling FusedAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " + "qkv_type = {torch.Tensor, Float8Tensor}. 
", qkv_dtype, + qkv_type, ) use_fused_attention = False # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and not _use_flash_attn_3: - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") - use_flash_attention = False - if use_flash_attention and _use_flash_attn_3 and is_training: - logger.debug( - "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" - ) - use_flash_attention = False + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 for FP8 attention") + use_flash_attention_2 = False + if use_flash_attention_3 and is_training: + if _flash_attn_3_is_installed: + logger.debug( + "Disabling FlashAttention 3 for FP8 training" + ) + use_flash_attention_3 = False if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False # Filter: KV cache - # backend | precision | KV cache | architecture | qkv_format - # -------------------------------------------------------------------------------------------- - # FusedAttention | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd - # | FP8 | non-paged | sm89+ | bshd,sbhd,thd - # FlashAttention v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd - # FlashAttention v3 | FP16/BF16 | non-paged/paged | sm80 | bshd,sbhd,thd - # | FP8 | non-paged/paged | sm80 | thd - # UnfusedDotProductAttention | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd + # backend | precision | KV cache | architecture | qkv_format | page_size + # --------------------------------------------------------------------------------------- + # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 + # | FP8 | non-paged | sm89+ | bshd,sbhd,thd | + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash 
v3 | FP16/BF16 | non-paged/paged | sm80 | bshd,sbhd,thd | >= 1 + # | FP8 | non-paged/paged | sm80 | thd | >= 1 + # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: if context_parallel: logger.debug( - "Disabling all backends as KV caching is not supported for context parallelism" + "Disabling all backends for KV caching with context parallelism" ) use_flash_attention = False use_fused_attention = False use_unfused_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention: - use_flash_attention = False - logger.debug("Disabling FlashAttention for FP8 KV caching") - if _use_flash_attn_3 and q_format != "thd": - _use_flash_attn_3 = False - logger.debug("Disabling FlashAttention 3 for FP8 KV caching in non-THD") + if use_flash_attention_3 and q_format != "thd": + if _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") + use_flash_attention_3 = False if use_fused_attention and inference_params.is_paged: - use_fused_attention = False logger.debug( - "Disabling FusedAttention for paged attention in FP8" + "Disabling FusedAttention for FP8 paged attention" ) + use_fused_attention = False if use_unfused_attention: - use_unfused_attention = False logger.debug("Disabling UnfusedAttention for FP8 attention") + use_unfused_attention = False else: if use_fused_attention and pad_between_seqs: use_fused_attention = False logger.debug("Disabling FusedAttention for pad_between_seqs = True and KV caching") - if use_flash_attention and pad_between_seqs: + if q_format == "thd" and pad_between_seqs: + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention 2 for pad_between_seqs = True and KV caching" + ) + if use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug( + "Disabling FlashAttention 3 for pad_between_seqs = True and KV caching" + ) use_flash_attention = False - logger.debug("Disabling 
FlashAttention for pad_between_seqs = True and KV caching") if inference_params.is_paged: - if use_fused_attention and cudnn_version < (9, 5, 0): - logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+") - use_fused_attention = False - if use_flash_attention and not _use_flash_attn_3 and not _flash_attn_2_5_plus: + if use_flash_attention_2 and not _flash_attn_2_5_plus: logger.debug( - "Disabling FlashAttention as paged attention requires flash-attn 2.5+, or 3.0" - " (Hopper only)" + "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" ) - use_flash_attention = False + use_flash_attention_2 = False # Filter: Head dimension - if use_flash_attention and head_dim_qk != head_dim_v: - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False - if use_flash_attention and ( + if head_dim_qk != head_dim_v: + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + use_flash_attention_2 = False + if use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 as it does not support MLA.") + use_flash_attention_3 = False + if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 or ( @@ -574,7 +594,7 @@ def get_attention_backend( ): if _flash_attn_is_installed: logger.debug( - "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " "head_dim_qk <= 256 (>192 requires sm80/90/100+). 
" "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", @@ -582,7 +602,13 @@ def get_attention_backend( head_dim_v, ".".join([str(i) for i in device_compute_capability]), ) - use_flash_attention = False + use_flash_attention_2 = False + if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): + if _flash_attn_3_is_installed: + logger.debug( + "Disabling FlashAttention 3 due to head_dim > 128" + ) + use_flash_attention_3 = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": logger.debug( @@ -596,8 +622,8 @@ def get_attention_backend( if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") use_unfused_attention = False - if use_flash_attention and pad_between_seqs: - if _flash_attn_is_installed: + if pad_between_seqs: + if (use_flash_attention_2 and _flash_attn_is_installed) or (use_flash_attention_3 and _flash_attn_is_installed): logger.debug( "Disabling FlashAttention for qkv_format = thd when there is " "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" @@ -605,9 +631,9 @@ def get_attention_backend( use_flash_attention = False # Filter: Dropout - if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3: + if attention_dropout != 0.0 and use_flash_attention_3: logger.debug("Disabling FlashAttention 3 for dropout") - _use_flash_attn_3 = False + use_flash_attention_3 = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -626,29 +652,29 @@ def get_attention_backend( "Disabling UnfusedDotProductAttention as it does not support context parallelism" ) use_unfused_attention = False - if context_parallel and use_flash_attention: + if context_parallel and (use_flash_attention_2 or use_flash_attention_3): if fp8 and fp8_meta["recipe"].fp8_dpa: - if _flash_attn_is_installed: + if _flash_attn_is_installed or _flash_attn_3_is_installed: logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" ) use_flash_attention = False if "bottom_right" in attn_mask_type: - if _flash_attn_is_installed: + if _flash_attn_is_installed or _flash_attn_3_is_installed: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal_bottom_right masking" ) use_flash_attention = False elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - if _flash_attn_is_installed: + if _flash_attn_is_installed or _flash_attn_3_is_installed: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal masking for cross-attention" ) use_flash_attention = False elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - if _flash_attn_is_installed: + if _flash_attn_is_installed or _flash_attn_3_is_installed: logger.debug( "Disabling FlashAttention as it does not support context parallelism with bias" " type of %s", @@ -656,7 +682,7 @@ def get_attention_backend( ) use_flash_attention = False elif qkv_format == "thd" and 
core_attention_bias_type != "no_bias": - if _flash_attn_is_installed: + if _flash_attn_is_installed or _flash_attn_3_is_installed: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " attention bias for THD format" @@ -714,52 +740,37 @@ def get_attention_backend( # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention for arbitrary mask") - use_flash_attention = False + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 for arbitrary mask") + use_flash_attention_2 = False + if use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 for arbitrary mask") + use_flash_attention_3 = False if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False if ( - use_flash_attention - and _use_flash_attn_3 + use_flash_attention_3 and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): logger.warning( "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + "causal mask." ) - _use_flash_attn_3 = False + use_flash_attention_3 = False if ( - use_flash_attention + use_flash_attention_2 and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): if _flash_attn_2_1_plus: logger.warning( - "Disabling FlashAttention as it only supports bottom-right-diagonal " + "Disabling FlashAttention 2 as it only supports bottom-right-diagonal " "causal mask since flash-attn 2.1. 
See " "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) - use_flash_attention = False - if not _flash_attn_is_installed: - _flash_attn_max_version = PkgVersion("2.1") - if ( - use_flash_attention - and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"] - and max_seqlen_q != max_seqlen_kv - ): - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.1") - elif not _flash_attn_2_1_plus and not _use_flash_attn_3: - logger.warning( - "Disabling FlashAttention as it only supports top-left-diagonal " - "causal mask before flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False + use_flash_attention_2 = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -790,19 +801,14 @@ def get_attention_backend( "with s_q > s_kv for cross-attention" ) use_fused_attention = False - if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if _use_flash_attn_3: - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) - _use_flash_attn_3 = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if not _flash_attn_is_installed: _flash_attn_version_required = PkgVersion("2.3") elif not _flash_attn_2_3_plus: logger.debug( "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" ) - use_flash_attention = False + use_flash_attention_2 = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -813,23 +819,28 @@ def get_attention_backend( # | | bottom_right (converts to a 'post_scale_bias' bias) # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias - if use_flash_attention and core_attention_bias_type == "alibi": - if _use_flash_attn_3: - logger.debug("Disabling 
FlashAttention 3 for ALiBi") - _use_flash_attn_3 = False - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.4") - elif not _flash_attn_2_4_plus: - logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") - use_flash_attention = False + if core_attention_bias_type == "alibi": + if use_flash_attention_3: + if _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 for ALiBi") + use_flash_attention_3 = False + if use_flash_attention_2: + if not _flash_attn_is_installed: + _flash_attn_version_required = PkgVersion("2.4") + elif not _flash_attn_2_4_plus: + logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") + use_flash_attention_2 = False - if use_flash_attention and ( + if ( core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias_shape is not None ): - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention for pre/post_scale_bias") - use_flash_attention = False + if use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 for pre/post_scale_bias") + use_flash_attention_2 = False + if use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 for pre/post_scale_bias") + use_flash_attention_3 = False fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias_shape = core_attention_bias_shape @@ -932,16 +943,16 @@ def get_attention_backend( # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes - if use_flash_attention and deterministic: + if use_flash_attention_2 and deterministic: if not _flash_attn_is_installed: _flash_attn_version_required = PkgVersion("2.4.1") - elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3: + elif not _flash_attn_2_4_1_plus: logger.warning( "Disabling FlashAttention as version <2.4.1 does not support deterministic " "execution. 
To use FlashAttention with deterministic behavior, " "please install flash-attn >= 2.4.1." ) - use_flash_attention = False + use_flash_attention_2 = False if use_fused_attention and deterministic: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: logger.debug("Disabling FusedAttention for determinism reasons") @@ -958,24 +969,41 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for determinism reasons") use_fused_attention = False - # All available backends - available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + # use_flash_attention may have been used for both FAv2 and FAv3 above + use_flash_attention_2 = use_flash_attention and use_flash_attention_2 + use_flash_attention_3 = use_flash_attention and use_flash_attention_3 # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. # When `FusedAttention` does not support the provided attention params, and `FlashAttention` # does, we recommend users to install flash-attn if not installed already. - if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed: - logger.warning( - "flash-attn may provide important feature support or performance improvement." - " Please install flash-attn %s.", - _get_supported_versions( - _flash_attn_version_required, - _flash_attn_max_version, - ), - ) - if use_flash_attention and not _flash_attn_is_installed and not _flash_attn_3_is_installed: - use_flash_attention = False - available_backends[0] = False + if not use_fused_attention: + if use_flash_attention_3 and not _flash_attn_3_is_installed: + logger.warning( + "flash-attn v3 may provide important feature support or performance improvement." + " Please install flash-attn v3 by \n%s", + _flash_attn_3_installation_steps, + ) + elif use_flash_attention_2 and not _flash_attn_is_installed: + logger.warning( + "flash-attn may provide important feature support or performance improvement." 
+ " Please install flash-attn %s.", + _get_supported_versions( + _flash_attn_version_required, + _flash_attn_max_version, + ), + ) + # All available backends + if use_flash_attention_2 and not _flash_attn_is_installed: + use_flash_attention_2 = False + if use_flash_attention_3 and not _flash_attn_3_is_installed: + use_flash_attention_3 = False + use_flash_attention = use_flash_attention_2 or use_flash_attention_3 + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + # use FAv3 when both are present + if use_flash_attention_2: + flash_attention_backend = _flash_attn_version + if use_flash_attention_3: + flash_attention_backend = _flash_attn_3_version logger.debug( "Available backends = {FlashAttention=%s, FusedAttention=%s%s," @@ -994,25 +1022,13 @@ def get_attention_backend( if ( use_flash_attention and use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - ): - if device_compute_capability >= (9, 0): + and device_compute_capability >= (9, 0) + ): logger.debug( "Disabling FlashAttention to give FusedAttention preference on Hopper+ " "for performance reasons" ) use_flash_attention = False - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["FP8"] - and _use_flash_attn_3 - ): - logger.debug( - "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " - "in FP8 execution" - ) - use_flash_attention = False # Selected backend if use_flash_attention: @@ -1031,6 +1047,7 @@ def get_attention_backend( global _attention_backends _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend _attention_backends["use_fused_attention"] = use_fused_attention _attention_backends["fused_attention_backend"] = fused_attention_backend _attention_backends["use_unfused_attention"] = use_unfused_attention @@ -1038,6 +1055,7 @@ def get_attention_backend( 
return ( use_flash_attention, + flash_attention_backend, use_fused_attention, fused_attention_backend, use_unfused_attention, @@ -1886,6 +1904,7 @@ def forward( cp_global_ranks, cp_stream, quantizers, + use_fa_v3, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") @@ -2043,12 +2062,12 @@ def forward( if use_fused_attention: softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) else: - softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3 + softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or use_fa_v3 flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: + if use_fa_v3: #if qkv_format == "thd": # flash_attn_fwd = _flash_attn_varlen_fwd_v3 #else: @@ -2062,7 +2081,7 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3: + if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or use_fa_v3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 @@ -2242,12 +2261,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[3] elif i <= rank: if pad_between_seqs_q: @@ -2352,7 +2371,7 @@ def forward( max_seqlen_q, max_seqlen_kv // 2, ] - if _use_flash_attn_3 or ( + if use_fa_v3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) @@ -2378,12 +2397,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = 
fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: @@ -2497,7 +2516,7 @@ def forward( max_seqlen_q // 2, max_seqlen_kv, ] - if _use_flash_attn_3 or ( + if use_fa_v3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) @@ -2523,12 +2542,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: @@ -2637,12 +2656,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[3] if i > 0: @@ -2831,6 +2850,7 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.use_fa_v3 = use_fa_v3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") return out_ret @@ -2839,6 +2859,7 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + use_fa_v3 = ctx.use_fa_v3 cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a @@ -3002,7 +3023,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: + 
if use_fa_v3: if ctx.qkv_format == "thd": flash_attn_bwd = _flash_attn_varlen_bwd_v3 else: @@ -3152,12 +3173,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, 0) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 - if not _use_flash_attn_3: + if not use_fa_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3267,12 +3288,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) if _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not _use_flash_attn_3: + if not use_fa_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3384,12 +3405,12 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not _use_flash_attn_3: + if not use_fa_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3478,12 +3499,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) 
elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not _use_flash_attn_3: + if not use_fa_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout, @@ -3791,6 +3812,7 @@ def forward( window_size, cp_group, cp_stream, + use_fa_v3, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -3816,7 +3838,7 @@ def forward( flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: + if use_fa_v3: #if qkv_format == "thd": # flash_attn_fwd = _flash_attn_varlen_fwd_v3 #else: @@ -3943,7 +3965,7 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size_per_step[i] elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -3959,12 +3981,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not use_fa_v3: rng_states[i] = fa_outputs[3] if i > 0: @@ -4009,6 +4031,7 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.use_fa_v3 = use_fa_v3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") return out @@ -4016,6 +4039,7 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") + use_fa_v3 = ctx.use_fa_v3 cp_size = get_distributed_world_size(ctx.cp_group) rank 
= get_distributed_rank(ctx.cp_group) @@ -4064,7 +4088,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: + if use_fa_v3: if ctx.qkv_format == "thd": flash_attn_bwd = _flash_attn_varlen_bwd_v3 else: @@ -4137,7 +4161,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, max_seqlen_kv, ] - if not _use_flash_attn_3: + if not use_fa_v3: fa_backward_kwargs["rng_state"] = rng_states[i] if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] @@ -4257,6 +4281,7 @@ def forward( cp_group, cp_stream, quantizers, + use_fa_v3, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -4281,7 +4306,7 @@ def forward( flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: + if use_fa_v3: #if qkv_format == "thd": # flash_attn_fwd = _flash_attn_varlen_fwd_v3 #else: @@ -4295,7 +4320,7 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size[0] @@ -4418,10 +4443,10 @@ def forward( ) if not _flash_attn_2_7_0_plus: out, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + rng_state = fa_outputs[7] if not use_fa_v3 else None else: out, softmax_lse = fa_outputs[0], fa_outputs[1] - rng_state = fa_outputs[3] if not _use_flash_attn_3 else None + rng_state = fa_outputs[3] if not use_fa_v3 else None aux_ctx_tensors = [softmax_lse, rng_state] chunk_ids_for_a2a = 
get_seq_chunk_ids_for_reordering(cp_size, out.device, False) @@ -4504,6 +4529,7 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.use_fa_v3 = use_fa_v3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") return out_ret @@ -4511,6 +4537,7 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") + use_fa_v3 = ctx.use_fa_v3 cp_size = get_distributed_world_size(ctx.cp_group) ( @@ -4592,7 +4619,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: + if use_fa_v3: if ctx.qkv_format == "thd": flash_attn_bwd = _flash_attn_varlen_bwd_v3 else: @@ -4605,7 +4632,7 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = ctx.window_size elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = ctx.window_size[0] @@ -4681,7 +4708,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if not _use_flash_attn_3: + if not use_fa_v3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( dout, @@ -4778,6 +4805,7 @@ def attn_forward_func_with_cp( fp8=False, fp8_meta=None, quantizers=None, + use_fa_v3=False, ) -> torch.Tensor: """ Attention implementation with context parallelism. 
@@ -4845,15 +4873,15 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers] + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, use_fa_v3] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream] + args += [window_size, cp_group, cp_stream, use_fa_v3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers] + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_fa_v3] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -5807,6 +5835,7 @@ def forward( fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, inference_params: Optional[InferenceParams] = None, + flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), ) -> torch.Tensor: """flash-attn fprop""" @@ -5973,6 +6002,9 @@ def forward( batch_size * context_len, ) + use_fa_v3 = False + if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): + use_fa_v3 = True if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -6002,6 +6034,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, + use_fa_v3=use_fa_v3, ) else: @@ -6031,20 +6064,20 @@ def forward( qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type ): - func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 + func = flash_attn_func if not use_fa_v3 else flash_attn_func_v3 else: - if not _use_flash_attn_3: + if not use_fa_v3: func = flash_attn_varlen_func elif inference_params is None: func = flash_attn_varlen_func_v3 else: func = flash_attn_with_kvcache_v3 - if not _use_flash_attn_3 or 
inference_params is None: + if not use_fa_v3 or inference_params is None: fa_optional_forward_args_thd.append(cu_seqlens_q) fa_optional_forward_args_thd.append(cu_seqlens_kv) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) - if not _use_flash_attn_3: + if not use_fa_v3: fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: fa_optional_forward_kwargs["window_size"] = window_size @@ -6919,6 +6952,35 @@ def forward( ) else: with self.attention_dropout_ctx(): + print( + f"{max_seqlen_q=}", + f"{max_seqlen_kv=}", + f"{cu_seqlens_q=}", + f"{cu_seqlens_kv=}", + f"{cu_seqlens_q_padded=}", + f"{cu_seqlens_kv_padded=}", + f"{page_table_k=}", + f"{page_table_v=}", + f"{query_layer.shape}", + f"{key_layer.shape}", + f"{value_layer.shape}", + f"{qkv_dtype=}", + f"{core_attention_bias=}", + f"{self.softmax_scale=}", + f"{self.attention_dropout if self.training else 0.0=}", + f"{fast_zero_fill=}", + f"{qkv_layout=}", + f"{core_attention_bias_type=}", + f"{attn_mask_type=}", + f"{window_size=}", + f"{None=}", # rng_gen + f"{fused_attention_backend=}", + f"{use_FAv2_bwd=}", + f"{fp8=}", + f"{fp8_meta=}", + f"{quantizers=}", + f"{self.deterministic=}", + ) output = FusedAttnFunc.apply( self.training, max_seqlen_q, @@ -7655,6 +7717,15 @@ def forward( for x in [query_layer, key_layer, value_layer] ] + if query_layer.shape[0] == 2: + print('bbbbbbbbbbbbbbbbbb') + print('q', query_layer[0,0,0,:4]) + print('k', key_layer[0,0,0,:4]) + print('v', value_layer[0,0,0,:4]) + print('q', query_layer[1,0,0,:4]) + print('k', key_layer[1,0,0,:4]) + print('v', value_layer[1,0,0,:4]) + print('bbbbbbbbbbbbbbbbbb') ( key_layer, value_layer, @@ -7672,6 +7743,58 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None + if query_layer.shape[0] >= 7: + #print('q', query_layer[0,0,0,:4]) + #print('k', key_layer[0,0,0,:4]) + #print('v', value_layer[0,0,0,:4]) + #print('q', query_layer[1,6,0,:4]) + #print('k', key_layer[1,6,0,:4]) + 
#print('v', value_layer[1,6,0,:4]) + #print('xxxxxxx') + #print('q', query_layer[5,28,0,:4]) + #print('k', key_layer[5,28,0,:4]) + #print('v', value_layer[5,28,0,:4]) + #print('q', query_layer[6,15,0,:4]) + #print('k', key_layer[6,15,0,:4]) + #print('v', value_layer[6,15,0,:4]) + print('xxxxxxx') + #print('q', query_layer[5,26,0,:4]) + #print('k', key_layer[5,26,0,:4]) + #print('v', value_layer[5,26,0,:4]) + #print('q', query_layer[6,13,0,:4]) + #print('k', key_layer[6,13,0,:4]) + #print('v', value_layer[6,13,0,:4]) + #torch.save(query_layer, 'full_q.pt') + #torch.save(key_layer, 'full_k.pt') + #torch.save(value_layer, 'full_v.pt') + print('q', query_layer[5,35:37,0,:4]) + print('k', key_layer[5,35:37,0,:4]) + print('v', value_layer[5,35:37,0,:4]) + print('q', query_layer[6,22:24,0,:4]) + print('k', key_layer[6,22:24,0,:4]) + print('v', value_layer[6,22:24,0,:4]) + if query_layer.shape[0] == 2: + #torch.save(query_layer, 'partial_q.pt') + #torch.save(key_layer, 'partial_k.pt') + #torch.save(value_layer, 'partial_v.pt') + print('q', query_layer[0,0,0,:4]) + print('k', key_layer[0,36,0,:4]) + print('v', value_layer[0,36,0,:4]) + print('q', query_layer[1,0,0,:4]) + print('k', key_layer[1,23,0,:4]) + print('v', value_layer[1,23,0,:4]) + #print('q', query_layer[0,0,0,:4]) + #print('k', key_layer[0,26,0,:4]) + #print('v', value_layer[0,26,0,:4]) + #print('q', query_layer[1,0,0,:4]) + #print('k', key_layer[1,13,0,:4]) + #print('v', value_layer[1,13,0,:4]) + #print('q', query_layer[0,0,0,:4]) + #print('k', key_layer[0,35:37,0,:4]) + #print('v', value_layer[0,35:37,0,:4]) + #print('q', query_layer[1,0,0,:4]) + #print('k', key_layer[1,22:24,0,:4]) + #print('v', value_layer[1,22:24,0,:4]) # get accurate qkv_layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( @@ -7833,7 +7956,7 @@ def forward( fp8_meta=self.fp8_meta, inference_params=inference_params, ) - global _attention_backends, _use_flash_attn_3 + global _attention_backends#, 
_use_flash_attn_3 if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -7841,9 +7964,10 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: - _use_flash_attn_3 = _flash_attn_3_is_installed + #_use_flash_attn_3 = _flash_attn_3_is_installed ( use_flash_attention, + flash_attention_backend, use_fused_attention, fused_attention_backend, use_unfused_attention, @@ -7852,7 +7976,7 @@ def forward( if use_flash_attention: self.logger.info( "Running with FlashAttention backend (version %s)", - _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version, + flash_attention_backend, ) elif use_fused_attention: self.logger.info( @@ -7863,6 +7987,7 @@ def forward( self.logger.info("Running with UnfusedDotProductAttention backend") else: use_flash_attention = _attention_backends["use_flash_attention"] + flash_attention_backend = _attention_backends["fused_attention_backend"] use_fused_attention = _attention_backends["use_fused_attention"] fused_attention_backend = _attention_backends["fused_attention_backend"] use_unfused_attention = _attention_backends["use_unfused_attention"] @@ -7902,6 +8027,7 @@ def forward( fp8_meta=self.fp8_meta, quantizers=self.quantizers, inference_params=inference_params, + flash_attention_backend=flash_attention_backend, ) if use_fused_attention: From 38110a74f4ca2d923c56716fe4b2304c2278173d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 2 Mar 2025 16:20:29 -0800 Subject: [PATCH 149/239] WIP: tweaks to paged attn script Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 139 ++++++++++---------- 1 file changed, 68 insertions(+), 71 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py 
b/tests/pytorch/fused_attn/test_paged_attn.py index 8e7a6d4b8a..ccf645efd7 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -14,7 +14,6 @@ from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe -import transformer_engine.pytorch.fp8 as fp8 from transformer_engine.pytorch import fp8_autocast, fp8_model_init from transformer_engine.pytorch.transformer import ( TransformerLayer, @@ -22,7 +21,7 @@ from transformer_engine.pytorch.attention import ( DotProductAttention, InferenceParams, - _use_flash_attn_3, + _flash_attn_3_is_installed, ) from transformer_engine.pytorch.utils import ( get_device_compute_capability, @@ -36,7 +35,6 @@ _get_attention_backends, ) from tests.pytorch.test_numerics import assert_allclose -fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() # Initialize RNG state seed = 1234 @@ -49,8 +47,6 @@ param_types = [torch.float16] if is_bf16_compatible(): param_types.append(torch.bfloat16) -if fp8_available: - param_types.append(torch.float8_e4m3fn) model_configs_infer = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -388,30 +384,83 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FlashAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) -def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph): +@pytest.mark.parametrize("is_fp8", [False, True]) +def test_paged_attn(dtype, model, qkv_format, is_paged, backend, 
module, is_cuda_graph, is_fp8): logger = logging.getLogger("test_paged_attn") num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 config = model_configs_infer[model] + if backend == "FlashAttention" and _flash_attn_3_is_installed: + config_max_seqlen_q = config.max_seqlen_q + config_max_seqlen_kv = config.max_seqlen_kv + config.max_seqlen_q = 256 + config.max_seqlen_kv = 256 + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=is_fp8, + fp8_mha=is_fp8, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe + + # create a real-life simulation + max_batch_size = config.batch_size + page_size = None + total_num_pages = None + if is_paged: + page_size = 256 if backend == "FlashAttention" and _flash_attn_3_is_installed else 16 + config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) + total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) + else: + config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64) + sim = Simulation( + total_requests=config.total_requests, + max_seq_len=config.max_seqlen_kv, + max_ctx_len=config.max_ctx_len, + max_batch_size=max_batch_size, + poisson_rate=2, + ) + sim.print_setup(logger) - is_fp8 = dtype == torch.float8_e4m3fn - if is_fp8: - dtype = torch.bfloat16 + # initialize inference_params + inference_params = InferenceParams( + max_batch_size=max_batch_size, + max_seqlen_kv=config.max_seqlen_kv, + num_heads_kv=config.num_gqa_groups, + head_dim_k=config.head_dim_qk, + head_dim_v=config.head_dim_v, + dtype=dtype, + is_paged=is_paged, + page_size=page_size, + total_num_pages=total_num_pages, + num_heads_q=config.num_heads, + head_dim_q=config.head_dim_qk, + max_ctx_len=config.max_ctx_len, + qkv_format=qkv_format, + ) + for layer_number in range(1, num_layers + 1): + inference_params.allocate_memory(layer_number, qkv_format) # figure out supported backends inference_params_qkv_format = 
"bshd" + qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) if is_paged: - qkv_layout = "paged_kv_" + "_".join([inference_params_qkv_format] * 3) - else: - qkv_layout = "_".join([inference_params_qkv_format] * 3) - available_backends, fused_attn_backends = _get_attention_backends( + qkv_layout = "paged_kv_" + qkv_layout + available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, window_size=config.window_size, pad_between_seqs=False, + is_training=False, + fp8=is_fp8, + fp8_meta=fp8_meta, + inference_params=inference_params, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if backend == "FlashAttention" and not flash_attn_supported: @@ -425,16 +474,11 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") - # flash-attn requires page size >= 256 - if backend == "FlashAttention" and not _use_flash_attn_3: - config_max_seqlen_q = config.max_seqlen_q - config_max_seqlen_kv = config.max_seqlen_kv - config.max_seqlen_q = 256 - config.max_seqlen_kv = 256 - if is_fp8 and (qkv_format != "thd" or module != "DotProductAttention"): + if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention" and dtype == torch.float16): pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") # create full model + logger.info("=== Generating all tokens at once ===") model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference") # generate data for all requests @@ -444,7 +488,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs") # generate reference results - 
logger.info("=== Generating all tokens at once ===") if module == "DotProductAttention": full_output = full_inputs for m in model: @@ -458,56 +501,10 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda full_output[0] if isinstance(full_output, List) else full_output, ) - # simulate real-life inference - logger.info("=== Generating one token at a time ===") - max_batch_size = config.batch_size - page_size = None - total_num_pages = None - if is_paged: - page_size = 256 if backend == "FlashAttention" and not _use_flash_attn_3 else 16 - config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) - total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) - else: - config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64) - sim = Simulation( - total_requests=config.total_requests, - max_seq_len=config.max_seqlen_kv, - max_ctx_len=config.max_ctx_len, - max_batch_size=max_batch_size, - poisson_rate=2, - ) - sim.print_setup(logger) - - # initialize inference_params - inference_params = InferenceParams( - max_batch_size=max_batch_size, - max_seqlen_kv=config.max_seqlen_kv, - num_heads_kv=config.num_gqa_groups, - head_dim_k=config.head_dim_qk, - head_dim_v=config.head_dim_v, - dtype=dtype, - is_paged=is_paged, - page_size=page_size, - total_num_pages=total_num_pages, - num_heads_q=config.num_heads, - head_dim_q=config.head_dim_qk, - max_ctx_len=config.max_ctx_len, - qkv_format=qkv_format, - ) - for layer_number in range(1, num_layers + 1): - inference_params.allocate_memory(layer_number, qkv_format) - # create inference model + logger.info("=== Generating one token at a time ===") model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="inference", fp8_dpa=is_fp8, fp8_mha=is_fp8) - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=is_fp8, - fp8_mha=is_fp8, - ) # graph the model if necessary if 
is_cuda_graph: t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") @@ -653,7 +650,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda incremental_output = incremental_output[0] # compare results - tol = get_tols(module, backend, dtype=dtype) + tol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) for i, seq in enumerate(sim.t_seq_ids): token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": @@ -699,6 +696,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention" and not _use_flash_attn_3: + if backend == "FlashAttention" and _flash_attn_3_is_installed: config.max_seqlen_q = config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv From 07021c23cb61ecc52b457255587bc6234e1ab636 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Mar 2025 00:21:31 +0000 Subject: [PATCH 150/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_fused_attn.py | 4 +- tests/pytorch/fused_attn/test_paged_attn.py | 28 ++- .../common/fused_attn/fused_attn.cpp | 8 +- .../fused_attn_f16_arbitrary_seqlen.cu | 51 +--- .../common/fused_attn/fused_attn_fp8.cu | 50 +--- transformer_engine/common/fused_attn/utils.h | 19 +- transformer_engine/pytorch/attention.py | 223 +++++++++--------- transformer_engine/pytorch/csrc/extensions.h | 4 +- 8 files changed, 159 insertions(+), 228 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 6b333429cb..1ae3e3bd7f 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -201,8 +201,8 @@ def test(): fp8_meta=fp8_meta, inference_params=inference_params, ) - _, 
_, flash_attention_backend, fused_attention_backend, _, available_backends = get_attention_backend( - attention_params + _, _, flash_attention_backend, fused_attention_backend, _, available_backends = ( + get_attention_backend(attention_params) ) return available_backends, flash_attention_backend, fused_attention_backend diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index ccf645efd7..e9f9b76245 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -384,7 +384,7 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("model", model_configs_infer.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) @pytest.mark.parametrize("is_fp8", [False, True]) @@ -450,7 +450,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda inference_params_qkv_format = "bshd" qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) if is_paged: - qkv_layout = "paged_kv_" + qkv_layout + qkv_layout = "paged_kv_" + qkv_layout available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, @@ -474,7 +474,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") - if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention" and dtype == torch.float16): + if is_fp8 and not ( 
+ qkv_format == "thd" and module == "DotProductAttention" and dtype == torch.float16 + ): pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") # create full model @@ -503,7 +505,17 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda # create inference model logger.info("=== Generating one token at a time ===") - model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="inference", fp8_dpa=is_fp8, fp8_mha=is_fp8) + model = get_model( + module, + config, + dtype, + backend, + qkv_format, + num_layers, + mode="inference", + fp8_dpa=is_fp8, + fp8_mha=is_fp8, + ) # graph the model if necessary if is_cuda_graph: @@ -639,7 +651,11 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): for m in model: incremental_output = m( - *incremental_output if isinstance(incremental_output, List) else incremental_output, + *( + incremental_output + if isinstance(incremental_output, List) + else incremental_output + ), cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, @@ -678,7 +694,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda rtol=tol, ) if qkv_format == "thd": - print('i ', i, seq, cu_seqlens_q) + print("i ", i, seq, cu_seqlens_q) print(full_output[seq, sim.t_total_lens[i] - 1, :4]) print(incremental_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b7620931ef..6f9e5f4eb3 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -273,13 +273,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0) || // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) (cudnn_runtime_version >= 90600 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index e7a60f11ab..968fef5cb5 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -101,31 +101,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, + FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d_qk, d_v, //num_pages_k, //num_pages_v, //page_size_k, //page_size_v, - max_pages_per_seq_k, - max_pages_per_seq_v, - bias_b, - bias_h, - scaling_factor, - is_training, - dropout_probability, - layout, - bias_type, - mask_type, - window_size_left, - window_size_right, - true, - tensorType, + max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + scaling_factor, is_training, dropout_probability, layout, bias_type, + mask_type, window_size_left, window_size_right, true, tensorType, tensorType}; namespace fe = cudnn_frontend; @@ -535,32 +518,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, + FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d_qk, d_v, //0, //0, //0, //0, - 1, - 1, - bias_b, - bias_h, - scaling_factor, - true, - dropout_probability, - layout, - bias_type, - mask_type, - window_size_left, - window_size_right, - deterministic, - tensorType, - tensorType}; + 1, 1, bias_b, bias_h, scaling_factor, true, dropout_probability, + layout, bias_type, mask_type, window_size_left, window_size_right, + deterministic, tensorType, tensorType}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 88883a3ed0..9beadd0a2d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1672,32 +1672,14 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d, - d, + FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d, d, 
//0, //0, //0, //0, - 1, - 1, - bias_b, - bias_h, - scaling_factor, - is_training, - dropout_probability, - layout, - bias_type, - mask_type, - 0, - 0, - true, - fwd_tensor_type, - fwd_tensor_type}; + 1, 1, bias_b, bias_h, scaling_factor, is_training, + dropout_probability, layout, bias_type, mask_type, 0, 0, true, + fwd_tensor_type, fwd_tensor_type}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1976,31 +1958,13 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d, - d, + FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d, d, //0, //0, //0, //0, - 1, - 1, - bias_b, - bias_h, - scaling_factor, - true, - dropout_probability, - layout, - bias_type, - mask_type, - 0, - 0, - false, - fwd_tensor_type, + 1, 1, bias_b, bias_h, scaling_factor, true, dropout_probability, + layout, bias_type, mask_type, 0, 0, false, fwd_tensor_type, bwd_tensor_type}; namespace fe = cudnn_frontend; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 4766f80e34..8734bb3af1 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -115,17 +115,16 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, - //num_pages_k, num_pages_v, page_size_k, page_size_v, - max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, - window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + //num_pages_k, num_pages_v, page_size_k, page_size_v, + max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, + dropoutProbability, layout, mask_type, window_size_left, window_size_right, + deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < std::tie(rhs.b, rhs.h, 
rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, - //rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, - rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, - rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, - rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, - rhs.bwd_tensor_type); + //rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, + rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, + rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 813438ffc0..3cf3945adb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -194,7 +194,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False -#_use_flash_attn_3 = False +# _use_flash_attn_3 = False _flash_attn_3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install @@ -205,10 +205,10 @@ def _get_supported_versions(version_min, version_max): try: _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: - fa_logger.debug( - "flash-attn v3 is not installed. To use, please install it by \n%s", - _flash_attn_3_installation_steps, - ) + fa_logger.debug( + "flash-attn v3 is not installed. 
To use, please install it by \n%s", + _flash_attn_3_installation_steps, + ) else: from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flash_attn_3.flash_attn_interface import ( @@ -219,17 +219,18 @@ def _get_supported_versions(version_min, version_max): ) from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 - #from flash_attn_3.flash_attn_interface import ( + + # from flash_attn_3.flash_attn_interface import ( # _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, - #) - #from flash_attn_3.flash_attn_interface import ( + # ) + # from flash_attn_3.flash_attn_interface import ( # _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, - #) - + # ) + _flash_attn_3_is_installed = True _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") - #_use_flash_attn_3 = True - + # _use_flash_attn_3 = True + _attention_backends = { "attention_params": None, "use_flash_attention": None, @@ -444,7 +445,7 @@ def get_attention_backend( # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. 
- global _flash_attn_version_required, _flash_attn_max_version#, _use_flash_attn_3 + global _flash_attn_version_required, _flash_attn_max_version # , _use_flash_attn_3 # get q/kv format qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) @@ -517,9 +518,7 @@ def get_attention_backend( use_flash_attention_2 = False if use_flash_attention_3 and is_training: if _flash_attn_3_is_installed: - logger.debug( - "Disabling FlashAttention 3 for FP8 training" - ) + logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") @@ -536,9 +535,7 @@ def get_attention_backend( # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: if context_parallel: - logger.debug( - "Disabling all backends for KV caching with context parallelism" - ) + logger.debug("Disabling all backends for KV caching with context parallelism") use_flash_attention = False use_fused_attention = False use_unfused_attention = False @@ -548,9 +545,7 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") use_flash_attention_3 = False if use_fused_attention and inference_params.is_paged: - logger.debug( - "Disabling FusedAttention for FP8 paged attention" - ) + logger.debug("Disabling FusedAttention for FP8 paged attention") use_fused_attention = False if use_unfused_attention: logger.debug("Disabling UnfusedAttention for FP8 attention") @@ -605,9 +600,7 @@ def get_attention_backend( use_flash_attention_2 = False if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): if _flash_attn_3_is_installed: - logger.debug( - "Disabling FlashAttention 3 due to head_dim > 128" - ) + logger.debug("Disabling FlashAttention 3 due to head_dim > 128") use_flash_attention_3 = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if 
use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": @@ -623,7 +616,9 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") use_unfused_attention = False if pad_between_seqs: - if (use_flash_attention_2 and _flash_attn_is_installed) or (use_flash_attention_3 and _flash_attn_is_installed): + if (use_flash_attention_2 and _flash_attn_is_installed) or ( + use_flash_attention_3 and _flash_attn_is_installed + ): logger.debug( "Disabling FlashAttention for qkv_format = thd when there is " "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" @@ -755,8 +750,7 @@ def get_attention_backend( and max_seqlen_q != max_seqlen_kv ): logger.warning( - "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " - "causal mask." + "Disabling FlashAttention 3 as it only supports bottom-right-diagonal causal mask." ) use_flash_attention_3 = False if ( @@ -1019,16 +1013,12 @@ def get_attention_backend( ) # Select FusedAttention for performance - if ( - use_flash_attention - and use_fused_attention - and device_compute_capability >= (9, 0) - ): - logger.debug( - "Disabling FlashAttention to give FusedAttention preference on Hopper+ " - "for performance reasons" - ) - use_flash_attention = False + if use_flash_attention and use_fused_attention and device_compute_capability >= (9, 0): + logger.debug( + "Disabling FlashAttention to give FusedAttention preference on Hopper+ " + "for performance reasons" + ) + use_flash_attention = False # Selected backend if use_flash_attention: @@ -1065,12 +1055,12 @@ def get_attention_backend( @torch.no_grad() def get_padding_mask( - batch_size: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - max_seqlen_q: int, - max_seqlen_kv: int, - ): + batch_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_q: int, + max_seqlen_kv: int, +): """Convert cu_seqlens to attention_mask""" seqlens_q = 
cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] @@ -2068,9 +2058,9 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_fa_v3: - #if qkv_format == "thd": + # if qkv_format == "thd": # flash_attn_fwd = _flash_attn_varlen_fwd_v3 - #else: + # else: # flash_attn_fwd = _flash_attn_fwd_v3 flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) @@ -3839,9 +3829,9 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_fa_v3: - #if qkv_format == "thd": + # if qkv_format == "thd": # flash_attn_fwd = _flash_attn_varlen_fwd_v3 - #else: + # else: # flash_attn_fwd = _flash_attn_fwd_v3 flash_attn_fwd = _flash_attn_fwd_v3 else: @@ -4307,9 +4297,9 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_fa_v3: - #if qkv_format == "thd": + # if qkv_format == "thd": # flash_attn_fwd = _flash_attn_varlen_fwd_v3 - #else: + # else: # flash_attn_fwd = _flash_attn_fwd_v3 flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = window_size @@ -5299,9 +5289,7 @@ def forward( qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) if inference_params is not None and inference_params.is_paged: - key_layer, value_layer = inference_params.convert_paged_to_nonpaged( - self.layer_number - ) + key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now @@ -6060,10 +6048,7 @@ def forward( # | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e. 
# | | bshd/sbhd/thd + padding fa_optional_forward_args_thd = [] - if ( - qkv_format in ["bshd", "sbhd"] - and "padding" not in attn_mask_type - ): + if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: func = flash_attn_func if not use_fa_v3 else flash_attn_func_v3 else: if not use_fa_v3: @@ -6115,7 +6100,9 @@ def forward( cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens if inference_params.is_paged: - fa_3_optional_forward_kwargs["page_table"] = inference_params.cache_manager.page_table[:batch_size] + fa_3_optional_forward_kwargs["page_table"] = ( + inference_params.cache_manager.page_table[:batch_size] + ) if fp8: QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -7718,14 +7705,14 @@ def forward( ] if query_layer.shape[0] == 2: - print('bbbbbbbbbbbbbbbbbb') - print('q', query_layer[0,0,0,:4]) - print('k', key_layer[0,0,0,:4]) - print('v', value_layer[0,0,0,:4]) - print('q', query_layer[1,0,0,:4]) - print('k', key_layer[1,0,0,:4]) - print('v', value_layer[1,0,0,:4]) - print('bbbbbbbbbbbbbbbbbb') + print("bbbbbbbbbbbbbbbbbb") + print("q", query_layer[0, 0, 0, :4]) + print("k", key_layer[0, 0, 0, :4]) + print("v", value_layer[0, 0, 0, :4]) + print("q", query_layer[1, 0, 0, :4]) + print("k", key_layer[1, 0, 0, :4]) + print("v", value_layer[1, 0, 0, :4]) + print("bbbbbbbbbbbbbbbbbb") ( key_layer, value_layer, @@ -7744,57 +7731,57 @@ def forward( cu_seqlens_kv_padded = None if query_layer.shape[0] >= 7: - #print('q', query_layer[0,0,0,:4]) - #print('k', key_layer[0,0,0,:4]) - #print('v', value_layer[0,0,0,:4]) - #print('q', query_layer[1,6,0,:4]) - #print('k', key_layer[1,6,0,:4]) - #print('v', value_layer[1,6,0,:4]) - #print('xxxxxxx') - #print('q', query_layer[5,28,0,:4]) - #print('k', key_layer[5,28,0,:4]) - #print('v', value_layer[5,28,0,:4]) - #print('q', query_layer[6,15,0,:4]) - #print('k', 
key_layer[6,15,0,:4]) - #print('v', value_layer[6,15,0,:4]) - print('xxxxxxx') - #print('q', query_layer[5,26,0,:4]) - #print('k', key_layer[5,26,0,:4]) - #print('v', value_layer[5,26,0,:4]) - #print('q', query_layer[6,13,0,:4]) - #print('k', key_layer[6,13,0,:4]) - #print('v', value_layer[6,13,0,:4]) - #torch.save(query_layer, 'full_q.pt') - #torch.save(key_layer, 'full_k.pt') - #torch.save(value_layer, 'full_v.pt') - print('q', query_layer[5,35:37,0,:4]) - print('k', key_layer[5,35:37,0,:4]) - print('v', value_layer[5,35:37,0,:4]) - print('q', query_layer[6,22:24,0,:4]) - print('k', key_layer[6,22:24,0,:4]) - print('v', value_layer[6,22:24,0,:4]) + # print('q', query_layer[0,0,0,:4]) + # print('k', key_layer[0,0,0,:4]) + # print('v', value_layer[0,0,0,:4]) + # print('q', query_layer[1,6,0,:4]) + # print('k', key_layer[1,6,0,:4]) + # print('v', value_layer[1,6,0,:4]) + # print('xxxxxxx') + # print('q', query_layer[5,28,0,:4]) + # print('k', key_layer[5,28,0,:4]) + # print('v', value_layer[5,28,0,:4]) + # print('q', query_layer[6,15,0,:4]) + # print('k', key_layer[6,15,0,:4]) + # print('v', value_layer[6,15,0,:4]) + print("xxxxxxx") + # print('q', query_layer[5,26,0,:4]) + # print('k', key_layer[5,26,0,:4]) + # print('v', value_layer[5,26,0,:4]) + # print('q', query_layer[6,13,0,:4]) + # print('k', key_layer[6,13,0,:4]) + # print('v', value_layer[6,13,0,:4]) + # torch.save(query_layer, 'full_q.pt') + # torch.save(key_layer, 'full_k.pt') + # torch.save(value_layer, 'full_v.pt') + print("q", query_layer[5, 35:37, 0, :4]) + print("k", key_layer[5, 35:37, 0, :4]) + print("v", value_layer[5, 35:37, 0, :4]) + print("q", query_layer[6, 22:24, 0, :4]) + print("k", key_layer[6, 22:24, 0, :4]) + print("v", value_layer[6, 22:24, 0, :4]) if query_layer.shape[0] == 2: - #torch.save(query_layer, 'partial_q.pt') - #torch.save(key_layer, 'partial_k.pt') - #torch.save(value_layer, 'partial_v.pt') - print('q', query_layer[0,0,0,:4]) - print('k', key_layer[0,36,0,:4]) - print('v', 
value_layer[0,36,0,:4]) - print('q', query_layer[1,0,0,:4]) - print('k', key_layer[1,23,0,:4]) - print('v', value_layer[1,23,0,:4]) - #print('q', query_layer[0,0,0,:4]) - #print('k', key_layer[0,26,0,:4]) - #print('v', value_layer[0,26,0,:4]) - #print('q', query_layer[1,0,0,:4]) - #print('k', key_layer[1,13,0,:4]) - #print('v', value_layer[1,13,0,:4]) - #print('q', query_layer[0,0,0,:4]) - #print('k', key_layer[0,35:37,0,:4]) - #print('v', value_layer[0,35:37,0,:4]) - #print('q', query_layer[1,0,0,:4]) - #print('k', key_layer[1,22:24,0,:4]) - #print('v', value_layer[1,22:24,0,:4]) + # torch.save(query_layer, 'partial_q.pt') + # torch.save(key_layer, 'partial_k.pt') + # torch.save(value_layer, 'partial_v.pt') + print("q", query_layer[0, 0, 0, :4]) + print("k", key_layer[0, 36, 0, :4]) + print("v", value_layer[0, 36, 0, :4]) + print("q", query_layer[1, 0, 0, :4]) + print("k", key_layer[1, 23, 0, :4]) + print("v", value_layer[1, 23, 0, :4]) + # print('q', query_layer[0,0,0,:4]) + # print('k', key_layer[0,26,0,:4]) + # print('v', value_layer[0,26,0,:4]) + # print('q', query_layer[1,0,0,:4]) + # print('k', key_layer[1,13,0,:4]) + # print('v', value_layer[1,13,0,:4]) + # print('q', query_layer[0,0,0,:4]) + # print('k', key_layer[0,35:37,0,:4]) + # print('v', value_layer[0,35:37,0,:4]) + # print('q', query_layer[1,0,0,:4]) + # print('k', key_layer[1,22:24,0,:4]) + # print('v', value_layer[1,22:24,0,:4]) # get accurate qkv_layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( @@ -7956,7 +7943,7 @@ def forward( fp8_meta=self.fp8_meta, inference_params=inference_params, ) - global _attention_backends#, _use_flash_attn_3 + global _attention_backends # , _use_flash_attn_3 if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -7964,7 +7951,7 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = 
True if _attention_backends["backend_selection_requires_update"]: - #_use_flash_attn_3 = _flash_attn_3_is_installed + # _use_flash_attn_3 = _flash_attn_3_is_installed ( use_flash_attention, flash_attention_backend, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index ef0188fb36..94ab7f5088 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -74,8 +74,8 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, - torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, - bool is_non_paged); + torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged); /*************************************************************************************************** * GEMM From c5d6a069afbd460c6c4766e97cf21da92211997c Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Mon, 3 Mar 2025 11:31:20 +0800 Subject: [PATCH 151/239] [JAX] THD ring attention (#1454) * Support THD + ring attention for self attn Signed-off-by: Reese Wang * Consolidate reorder strategy Signed-off-by: Reese Wang * Fix dataclass frozen issue Signed-off-by: Reese Wang * Remove redundant code Signed-off-by: Reese Wang * Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extension Signed-off-by: Reese Wang * Fix lint Signed-off-by: Reese Wang * Refine P2P helper check_supported Signed-off-by: Reese Wang * Add segment_ids/pos check Signed-off-by: Reese Wang * Fixup Signed-off-by: Reese Wang * Add dual chunk swap example Signed-off-by: Reese Wang * Align different reorder code structure Signed-off-by: 
Reese Wang --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen --- tests/jax/test_distributed_fused_attn.py | 60 +- tests/jax/test_fused_attn.py | 89 ++- transformer_engine/jax/attention.py | 123 ++- .../jax/cpp_extensions/attention.py | 714 ++++++++++++------ .../jax/csrc/extensions/attention.cpp | 4 + 5 files changed, 693 insertions(+), 297 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 898993f5d1..2abcb28dec 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -23,6 +23,7 @@ reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, CPStrategy, + ReorderStrategy, ) @@ -210,29 +211,29 @@ def test_cross_attn( "data_shape", [ # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. - pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"), + pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ], ) -@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) +@pytest.mark.parametrize("kv_groups", [1, 8]) +@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( - "attn_mask_type", + "qkv_layout, attn_mask_type", [ - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), - pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), - ], -) -@pytest.mark.parametrize("dtype", [jnp.bfloat16]) -@pytest.mark.parametrize( - "qkv_layout", - [ - pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), + pytest.param( 
+ QKVLayout.THD_THD_THD, + AttnMaskType.PADDING_CAUSAL_MASK, + id="THD_SEPARATE-PADDING_CAUSAL", + ), ], ) @pytest.mark.parametrize( "load_balanced", - [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], ) class TestDistributedContextParallelSelfAttn: @@ -265,7 +266,6 @@ def impl_test_context_parallel_attn( data_shape = batch, seqlen, num_head, hidden num_kv_heads = num_head // kv_groups - scaling_factor = 1.0 / np.sqrt(num_head) runner = FusedAttnRunner( batch, @@ -282,7 +282,7 @@ def impl_test_context_parallel_attn( qkv_layout, bias_shape, None, - SeqDescFormat.Seqlens, + SeqDescFormat.SegmentIDs, number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, @@ -297,7 +297,7 @@ def check_has_backend_for_mask(mask_type): dtype, qkv_layout, attn_bias_type, - attn_mask_type, + mask_type, dropout_prob, num_head, num_kv_heads, @@ -340,6 +340,8 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, ): + if qkv_layout.is_thd(): + pytest.skip("THD doesn't support all gather context parallelism.") return self.impl_test_context_parallel_attn( device_count, mesh_shape, @@ -377,7 +379,10 @@ def test_context_parallel_ring_attn( else: os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" - self.impl_test_context_parallel_attn( + if qkv_layout.is_thd() and not load_balanced: + pytest.skip("THD + ring doesn't support unbalanced context parallelism.") + + return self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -404,17 +409,26 @@ class TestReorderCausalLoadBalancing: ], ) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) - def test(self, cp_size, shape, qkv_format): + @pytest.mark.parametrize( + "reorder_strategy", + [ + pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"), + pytest.param(ReorderStrategy.Striped, id="Striped"), + ], + ) + def test(self, cp_size, shape, qkv_format, 
reorder_strategy): tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) + seq_dim = 1 if qkv_format == QKVFormat.SBHD: tensor = tensor.swapaxes(0, 1) + seq_dim = 0 ref = tensor.copy() - reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2]) - inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2]) + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3]) - reordered = reorder(tensor, cp_size, qkv_format) - inversed = inverse(reordered, cp_size, qkv_format) + reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim) + inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim) assert jnp.array_equal(inversed, ref) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index ff4139ee51..037e364a7e 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -28,12 +28,14 @@ AttnBiasType, AttnMaskType, QKVLayout, + QKVFormat, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, fused_attn, make_swa_mask, SequenceDescriptor, CPStrategy, + ReorderStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.transformer_engine_jax import ( @@ -347,9 +349,9 @@ def _check_configs(self): self.backend = FusedAttnHelper( self.dtype, self.dtype, - self.qkv_layout.value, - self.attn_bias_type.value, - self.attn_mask_type.value, + self.qkv_layout, + self.attn_bias_type, + self.attn_mask_type, self.dropout_prob, self.num_heads_q, self.num_heads_kv, @@ -500,7 +502,8 @@ def generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) - if self.qkv_layout == QKVLayout.T3HD: + # TODO(rewang): record only self attention and find the reason of cross attention + if self.qkv_layout == QKVLayout.T3HD or 
self.max_seqlen_q == self.max_seqlen_kv: self.segment_ids_kv = self.segment_ids_q self.segment_pos_kv = self.segment_pos_q self.pad_kv = self.pad_q @@ -536,6 +539,30 @@ def generate_random_segment_ids( self.window_size, ) + if self.cp_size > 1 and self.cp_load_balanced: + if self.qkv_layout.is_thd(): + reorder_strategy = ReorderStrategy.Striped + else: + reorder_strategy = ReorderStrategy.DualChunkSwap + + seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1 + self.cp_reorder_fn = partial( + reorder_causal_load_balancing, + strategy=reorder_strategy, + cp_size=self.cp_size, + seq_dim=seq_dim, + ) + self.cp_inverse_reorder_fn = partial( + inverse_reorder_causal_load_balancing, + strategy=reorder_strategy, + cp_size=self.cp_size, + seq_dim=seq_dim, + ) + else: + # no-ops for non cp or non load balanced + self.cp_reorder_fn = lambda x: x + self.cp_inverse_reorder_fn = lambda x: x + # Test different input formats if self.qkv_layout.is_thd(): match self.seq_desc_format: @@ -548,8 +575,14 @@ def generate_random_segment_ids( ) case SeqDescFormat.SegmentIDs: self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( - (self.segment_ids_q, self.segment_ids_kv), - (self.segment_pos_q, self.segment_pos_kv), + ( + self.cp_reorder_fn(self.segment_ids_q), + self.cp_reorder_fn(self.segment_ids_kv), + ), + ( + self.cp_reorder_fn(self.segment_pos_q), + self.cp_reorder_fn(self.segment_pos_kv), + ), ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") @@ -605,7 +638,12 @@ def generate_random_segment_ids( case _: def to_dp_shardings(x): - pspec = PartitionSpec(self.mesh_resource.dp_resource) + if x.ndim == 1: + pspec = PartitionSpec(self.mesh_resource.dp_resource) + else: + pspec = PartitionSpec( + self.mesh_resource.dp_resource, self.mesh_resource.cp_resource + ) return NamedSharding(self.mesh, pspec) self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor) @@ -637,24 +675,6 @@ def to_dp_shardings(x): 
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) - # Softmax aux sharding - - if self.cp_size > 1 and self.cp_load_balanced: - self.cp_reorder_fn = partial( - reorder_causal_load_balancing, - cp_size=self.cp_size, - tensor_format=self.qkv_layout.get_qkv_format(), - ) - self.cp_inverse_reorder_fn = partial( - inverse_reorder_causal_load_balancing, - cp_size=self.cp_size, - tensor_format=self.qkv_layout.get_qkv_format(), - ) - else: - # no-ops for non cp or non load balanced - self.cp_reorder_fn = lambda x: x - self.cp_inverse_reorder_fn = lambda x: x - def test_forward(self): """ Test forward without JIT @@ -733,15 +753,24 @@ def test_backward(self): self._setup_inputs() - def grad_func(func, *args, **kwargs): + def grad_func(func, *args, cp_reverse_out=False, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient gradient_multiplier = self.max_seqlen_q * self.num_heads_q if self.attn_mask_type.is_causal(): gradient_multiplier /= 10 # Keep only valid result for the gradient - ret_valid = jnp.where( - self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs) - ) + if not cp_reverse_out: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + func(*args, **kwargs), + ) + else: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + self.cp_inverse_reorder_fn(func(*args, **kwargs)), + ) return ( jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier ).astype(self.dtype) @@ -787,7 +816,7 @@ def grad_func(func, *args, **kwargs): jitted_primitive = jit( value_and_grad( lambda q, k, v, bias, *args: grad_func( - customcall_fused_dpa, q, k, v, bias, *args, **kwargs + customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs ), arg_nums, ), diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 
a8245b533e..9b93faeb55 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -135,6 +135,39 @@ def is_thd(self): """ return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] + def to_qkvpacked(self): + """ + Return the corresponding qkvpacked format, useful when adjusting q, k, v layout + """ + qkv_format = self.get_qkv_format() + if qkv_format == QKVFormat.BSHD: + return QKVLayout.BS3HD + if qkv_format == QKVFormat.THD: + return QKVLayout.T3HD + raise ValueError(f"Unsupported {qkv_format=}") + + def to_kvpacked(self): + """ + Return the corresponding kvpacked format, useful when adjusting q, k, v layout + """ + qkv_format = self.get_qkv_format() + if qkv_format == QKVFormat.BSHD: + return QKVLayout.BSHD_BS2HD + if qkv_format == QKVFormat.THD: + return QKVLayout.THD_T2HD + raise ValueError(f"Unsupported {qkv_format=}") + + def to_separate(self): + """ + Return the corresponding separate format, useful when adjusting q, k, v layout + """ + qkv_format = self.get_qkv_format() + if qkv_format == QKVFormat.BSHD: + return QKVLayout.BSHD_BSHD_BSHD + if qkv_format == QKVFormat.THD: + return QKVLayout.THD_THD_THD + raise ValueError(f"Unsupported {qkv_format=}") + class CPStrategy(Enum): """Defines the context parallel strategies of Jax fused attention. @@ -149,6 +182,28 @@ class CPStrategy(Enum): RING = 2 +class ReorderStrategy(Enum): + """ + Defines the tokens re-order strategy for context parallel load balancing for causal mask. + + - DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between + GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the + mulitple of 2 * cp_size. 
+ Examples: + - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15]; + - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3] + + - Striped: This strategy distributes the tokens in a striped (interleaved) manner across + the sequence. This is currently used for THD load balance. + Example: Consider 4 GPUs with seqlens=16. + - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15] + - After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15] + """ + + DualChunkSwap = 0 + Striped = 1 + + def make_swa_mask( segment_pos_q: jnp.ndarray, segment_pos_kv: jnp.ndarray, @@ -243,9 +298,9 @@ def make_helper(attn_mask_type): return tex.FusedAttnHelper( q_dtype, kv_dtype, - qkv_layout.value, - attn_bias_type.value, - attn_mask_type.value, + qkv_layout, + attn_bias_type, + attn_mask_type, dropout_probability, q_num_heads, kv_num_heads, @@ -276,16 +331,24 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen -def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): +def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int): """Reorders a tensor for load balancing the compute of causal attention.""" - seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 - return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False) + if strategy == ReorderStrategy.DualChunkSwap: + return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) + if strategy == ReorderStrategy.Striped: + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False) + raise ValueError(f"Unsupported {strategy=}") -def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): +def inverse_reorder_causal_load_balancing( + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: 
int +): """Inverse operation of `reorder_causal_load_balancing`.""" - seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 - return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) + if strategy == ReorderStrategy.DualChunkSwap: + return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True) + if strategy == ReorderStrategy.Striped: + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True) + raise ValueError(f"Unsupported {strategy=}") def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq): @@ -412,8 +475,6 @@ def get_seqlens_and_offsets( """ Acquire the seqlens/offsets for cuDNN backend """ - attn_mask_type = AttnMaskType(attn_mask_type) - qkv_layout = QKVLayout(qkv_layout) q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos assert q_segment_ids.shape == q_segment_pos.shape @@ -589,9 +650,9 @@ def _legacy_fused_attn( Intra-sequence padding is not valid. The padded tokens can only on the right-most. Otherwise the results will be wrong. seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. 
@@ -608,16 +669,18 @@ def _legacy_fused_attn( # Check inputs qkv match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD: + case QKVLayout.BS3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + case QKVLayout.BSHD_BS2HD: assert ( len(qkv) == 2 ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + case QKVLayout.BSHD_BSHD_BSHD: assert ( len(qkv) == 3 ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + case _: + raise ValueError(f"Unknown {qkv_layout=}") # convert the mask to seqlens, mask doesn't support ragged offsets if not attn_mask_type.is_padding(): @@ -689,16 +752,18 @@ def fused_attn_thd( # Check inputs qkv match qkv_layout: - case NVTE_QKV_Layout.NVTE_T3HD: + case QKVLayout.T3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_THD_T2HD: + case QKVLayout.THD_T2HD: assert ( len(qkv) == 2 ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_THD_THD_THD: + case QKVLayout.THD_THD_THD: assert ( len(qkv) == 3 ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + case _: + raise ValueError(f"Unknown {qkv_layout=}") batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) assert q_seq_lens.shape == (batch, q_max_seqlen) @@ -789,9 +854,9 @@ def _fused_attn_fwd_rule( bias, sequence_descriptor, seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - qkv_layout=qkv_layout.value, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training, @@ -845,9 +910,9 @@ def _fused_attn_bwd_rule( output, dz, sequence_descriptor, - attn_bias_type=attn_bias_type.value, - 
attn_mask_type=attn_mask_type.value, - qkv_layout=qkv_layout.value, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training, @@ -903,9 +968,9 @@ def fused_attn( bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence. seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 51ff87ced1..409f08c7db 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. 
"""JAX/TE custom ops for attention""" -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial, reduce import operator import os @@ -17,17 +17,18 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi -from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor +from transformer_engine.jax.attention import ( + AttnBiasType, + AttnMaskType, + QKVLayout, + QKVFormat, + CPStrategy, + SequenceDescriptor, +) from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import ( - NVTE_Bias_Type, - NVTE_Mask_Type, - NVTE_QKV_Layout, - NVTE_QKV_Format, - NVTE_Fused_Attn_Backend, - nvte_get_qkv_format, -) +from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend + from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper from .misc import ( @@ -79,9 +80,9 @@ class _FusedAttnConfig: Passes static configuration properties of fused attention. 
""" - attn_bias_type: NVTE_Bias_Type - attn_mask_type: NVTE_Mask_Type - qkv_layout: NVTE_QKV_Layout + attn_bias_type: AttnBiasType + attn_mask_type: AttnMaskType + qkv_layout: QKVLayout scaling_factor: float dropout_probability: float is_training: bool @@ -99,9 +100,9 @@ class FusedAttnHelper: q_dtype: jnp.dtype kv_dtype: jnp.dtype - qkv_layout: NVTE_QKV_Layout - attn_bias_type: NVTE_Bias_Type - attn_mask_type: NVTE_Mask_Type + qkv_layout: QKVLayout + attn_bias_type: AttnBiasType + attn_mask_type: AttnMaskType dropout_probability: float q_num_heads: int kv_num_heads: int @@ -119,9 +120,9 @@ def get_fused_attn_backend(self): return transformer_engine_jax.get_fused_attn_backend( jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), - self.qkv_layout, - self.attn_bias_type, - self.attn_mask_type, + self.qkv_layout.value, + self.attn_bias_type.value, + self.attn_mask_type.value, self.dropout_probability, self.q_num_heads, self.kv_num_heads, @@ -140,24 +141,25 @@ def is_non_deterministic_allowed(): @staticmethod def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): """Parse qkv aval""" - match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape - kv_batch_shape = q_batch_shape - kv_max_seqlen = q_max_seqlen - num_gqa_groups = attn_heads - kv_head_dim = q_head_dim - assert nqkv == 3 - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape - assert nkv == 2 - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert k_aval.shape == v_aval.shape - case _: - raise ValueError(f"Unexpected {qkv_layout=}") + if 
qkv_layout.get_qkv_format() == QKVFormat.SBHD: + raise NotImplementedError + if qkv_layout.is_qkvpacked(): + *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape + kv_batch_shape = q_batch_shape + kv_max_seqlen = q_max_seqlen + num_gqa_groups = attn_heads + kv_head_dim = q_head_dim + assert nqkv == 3 + elif qkv_layout.is_kvpacked(): + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape + assert nkv == 2 + elif qkv_layout.is_separate(): + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape + assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}" + else: + raise ValueError(f"Unexpected {qkv_layout=}") assert q_batch_shape == kv_batch_shape assert q_head_dim == kv_head_dim assert q_aval.dtype == k_aval.dtype == v_aval.dtype @@ -310,7 +312,7 @@ def abstract( rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -330,9 +332,9 @@ def abstract( head_dim, config.scaling_factor, config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, jax_dtype_to_te_dtype(q_aval.dtype), config.is_training, config.max_segments_per_seq, @@ -385,7 +387,7 @@ def lowering( input_batch = reduce(operator.mul, batch_shape) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -419,9 +421,9 @@ def lowering( 
max_segments_per_seq=config.max_segments_per_seq, scaling_factor=float(config.scaling_factor), dropout_probability=float(config.dropout_probability), - bias_type=int(config.attn_bias_type), - mask_type=int(config.attn_mask_type), - qkv_layout=int(config.qkv_layout), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), is_training=config.is_training, deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=config.window_size[0], @@ -511,7 +513,7 @@ def impl( ) ) - if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -529,20 +531,11 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match config.qkv_layout: - case NVTE_QKV_Layout.NVTE_T3HD: - kv_max_seqlen = q_max_seqlen = q.shape[-4] - kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_T2HD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-4] - kv_batch = reduce(operator.mul, k.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_THD_THD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-3] - kv_batch = reduce(operator.mul, k.shape[:-3]) + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}" + kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 # cuDNN version < 9.3.0: @@ -610,29 +603,28 @@ def batcher(batched_args, batch_dims, *, config): def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): del result_infos q_spec = get_padded_spec(arg_infos[0]) - match config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - # q_spec = (...batch, 
q_seqlen, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) - ) - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - case _: - raise ValueError(f"Unsupported {config.qkv_layout=}") + if config.qkv_layout.is_qkvpacked(): + # q_spec = (...batch, q_seqlen, 3, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) + ) + elif config.qkv_layout.is_kvpacked(): + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + elif config.qkv_layout.is_separate(): + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + else: + raise ValueError(f"Unsupported {config.qkv_layout=}") rng_state_sharding = 
NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -705,7 +697,7 @@ def abstract( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -725,9 +717,9 @@ def abstract( head_dim, config.scaling_factor, config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, jax_dtype_to_te_dtype(q_aval.dtype), config.is_training, deterministic, @@ -787,7 +779,7 @@ def lowering( input_batch = reduce(operator.mul, batch_shape) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -824,9 +816,9 @@ def lowering( max_segments_per_seq=config.max_segments_per_seq, scaling_factor=float(config.scaling_factor), dropout_probability=float(config.dropout_probability), - bias_type=int(config.attn_bias_type), - mask_type=int(config.attn_mask_type), - qkv_layout=int(config.qkv_layout), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), is_training=config.is_training, deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=config.window_size[0], @@ -922,7 +914,7 @@ def impl( ) ) - if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -941,20 +933,11 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match config.qkv_layout: - case NVTE_QKV_Layout.NVTE_T3HD: - kv_max_seqlen = 
q_max_seqlen = q.shape[-4] - kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_T2HD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-4] - kv_batch = reduce(operator.mul, k.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_THD_THD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-3] - kv_batch = reduce(operator.mul, k.shape[:-3]) + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1 + kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 # cuDNN version < 9.3.0: @@ -1088,7 +1071,7 @@ def sharded_impl( config=config, ) global_dbias = local_dbias - if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not AttnBiasType.NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) return local_dq, local_dk, local_dv, global_dbias @@ -1098,7 +1081,7 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) -def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): +def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): """Reorders a tensor for load balancing the compute of causal attention.""" if cp_size == 1: return tensor @@ -1108,7 +1091,7 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu # Need to ensure we have 2 pairs to swap for balancing between cp ranks if tensor.shape[seq_dim] % (cp_size * 2) != 0: - raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}") # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] @@ -1150,6 +1133,33 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu return 
combined.reshape(ori_tensor_shape) +def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool): + """Reorders a tensor for load balancing with striped pattern""" + origin_shape = tensor.shape + if origin_shape[seq_dim] % cp_size != 0: + raise ValueError( + "Expected origin_shape[seq_dim] is multiple of cp_size but got" + f" {origin_shape[seq_dim]=} and {cp_size=}" + ) + + if not is_inverse: + new_shape = [ + *origin_shape[:seq_dim], + *[origin_shape[seq_dim] // cp_size, cp_size], + *origin_shape[seq_dim + 1 :], + ] + else: + new_shape = [ + *origin_shape[:seq_dim], + *[cp_size, origin_shape[seq_dim] // cp_size], + *origin_shape[seq_dim + 1 :], + ] + + chunked_tensor = tensor.reshape(new_shape) + reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1) + return reordered_chunked_tensor.reshape(origin_shape) + + @dataclass(frozen=True) class _FusedAttnCPWithAllGatherHelper: """Helper class to assist with running the all-gather strategy for CP attention.""" @@ -1161,17 +1171,17 @@ def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused attention" - allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) - if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: + if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") - allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " 
@@ -1189,8 +1199,8 @@ def check_supported(self): def get_adjusted_mask(self): """Converts the mask for context parallelism.""" - if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: - return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: + return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type def get_step_config(self) -> _FusedAttnConfig: @@ -1217,14 +1227,13 @@ def ag(x): ) if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True) + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) return x - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return ag(k), v - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return ag(k), ag(v) + if self.config.qkv_layout.is_kvpacked(): + return ag(k), v + if self.config.qkv_layout.is_separate(): + return ag(k), ag(v) return k, v # fall through @@ -1234,7 +1243,7 @@ def reduce_scatter_dkv(self, dk, dv): def rs(x): if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False) + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) return lax_paral_op( x, @@ -1245,11 +1254,10 @@ def rs(x): tiled=True, ) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return rs(dk), dv - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return rs(dk), rs(dv) + if self.config.qkv_layout.is_kvpacked(): + return rs(dk), dv + if self.config.qkv_layout.is_separate(): + return rs(dk), rs(dv) return dk, dv # fall through @@ -1286,11 +1294,10 @@ def slice_kv(self, k, v, slice_seq_len): def sliced(x): return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return sliced(k), v - 
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return sliced(k), sliced(v) + if self.config.qkv_layout.is_kvpacked(): + return sliced(k), v + if self.config.qkv_layout.is_separate(): + return sliced(k), sliced(v) return k, v # fall through @@ -1300,13 +1307,12 @@ def pad_kv(self, dk, dv, pad_seq_len): def pad(x, npad): return jnp.pad(x, npad, "constant", constant_values=0.0) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] - return pad(dk, npad), dv - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] - return pad(dk, npad), pad(dv, npad) + if self.config.qkv_layout.is_kvpacked(): + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] + return pad(dk, npad), dv + if self.config.qkv_layout.is_separate(): + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] + return pad(dk, npad), pad(dv, npad) return dk, dv # fall through @@ -1378,7 +1384,7 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): results = [] for sub_idx in range(2): - if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked else: k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) @@ -1514,7 +1520,7 @@ def _cross_attn_bwd( results = [] for sub_idx in range(2): - if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked else: k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) @@ -1544,7 +1550,7 @@ def _cross_attn_bwd( ) # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. 
- if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK: + if config.attn_mask_type != AttnMaskType.NO_MASK: pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) @@ -1614,24 +1620,31 @@ def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused ring attention" - allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + if self.config.qkv_layout.is_thd(): + allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] + else: + allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] + if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) - if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: + if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") - allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + if self.config.qkv_layout.is_thd(): + allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK] + else: + allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) - if self.config.max_segments_per_seq != 1: + if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1: raise ValueError( f"{header} only supports max_segments_per_seq == 1 got:" f" {self.config.max_segments_per_seq}" @@ -1655,7 +1668,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=attn_mask_type, - qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + 
qkv_layout=QKVLayout.BSHD_BS2HD, scaling_factor=self.config.scaling_factor, dropout_probability=self.config.dropout_probability, is_training=self.config.is_training, @@ -1668,21 +1681,19 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: def stack_kv(self, k, v): """Stacks k and v tensors if not stacked.""" _not_used = jnp.zeros(0, dtype=k.dtype) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return k - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return jnp.stack([k, v], axis=2) + if self.config.qkv_layout.is_kvpacked(): + return k + if self.config.qkv_layout.is_separate(): + return jnp.stack([k, v], axis=2) return _not_used def unstack_kv(self, kv): """Un-stacks k and v tensors if not stacked.""" _not_used = jnp.zeros(0, dtype=kv.dtype) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return kv, _not_used - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return jnp.unstack(kv, axis=2) + if self.config.qkv_layout.is_kvpacked(): + return kv, _not_used + if self.config.qkv_layout.is_separate(): + return jnp.unstack(kv, axis=2) return _not_used, _not_used # fall through def permute_kv(self, kv, cp_perm): @@ -1803,8 +1814,8 @@ def mask_compute(attn_mask_type): ) return output_per_step, softmax_aux_per_step - causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) - no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) + causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK) + no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK) def half_kv_no_mask_compute(): q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) @@ -1824,7 +1835,7 @@ def half_kv_no_mask_compute(): _kv_segment_ids, _q_segment_pos, _kv_segment_pos, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + config=helper.get_step_config(AttnMaskType.NO_MASK), ) return output_per_step, softmax_aux_per_step @@ -1846,7 +1857,7 @@ def half_q_no_mask_compute(): 
_kv_segment_ids, _q_segment_pos, _kv_segment_pos, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + config=helper.get_step_config(AttnMaskType.NO_MASK), ) output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) softmax_aux_per_step = jnp.concat( @@ -1865,7 +1876,7 @@ def skip_compute(): ) return output_per_step, softmax_aux_per_step - if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + if config.attn_mask_type == AttnMaskType.CAUSAL_MASK: # This is for nested jax.lax.cond def jax_cond_wrap(): if config.context_parallel_load_balanced: @@ -2019,8 +2030,8 @@ def mask_compute(attn_mask_type): ) return dq_per_step, dk_dv_per_step, dbias_per_step - causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) - no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) + causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK) + no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK) def half_kv_no_mask_compute(): q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) @@ -2043,7 +2054,7 @@ def half_kv_no_mask_compute(): _kv_segment_ids, _q_segment_pos, _kv_segment_pos, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + config=helper.get_step_config(AttnMaskType.NO_MASK), ) dk_dv_per_step = jnp.concat( [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1 @@ -2081,7 +2092,7 @@ def half_q_no_mask_compute(): _kv_segment_ids, _q_segment_pos, _kv_segment_pos, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + config=helper.get_step_config(AttnMaskType.NO_MASK), ) dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) return dq_per_step, dk_dv_per_step, dbias_per_step @@ -2089,7 +2100,7 @@ def half_q_no_mask_compute(): def skip_compute(): return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias) - if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + if config.attn_mask_type == AttnMaskType.CAUSAL_MASK: 
# This is for nested jax.lax.cond def jax_cond_wrap(): if config.context_parallel_load_balanced: @@ -2107,7 +2118,7 @@ def jax_cond_wrap(): kv_next, dk_dv = jnp.unstack(kv_dk_dv) dq = dq + dq_per_step dk_dv = dk_dv + dk_dv_per_step - if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not AttnBiasType.NO_BIAS: dbias = dbias + dbias_per_step return (kv_next, dq, dk_dv, dbias) @@ -2124,7 +2135,7 @@ def jax_cond_wrap(): dk_dv = helper.permute_kv(dk_dv, cp_perm) global_dbias = dbias - if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not AttnBiasType.NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) dk, dv = helper.unstack_kv(dk_dv) @@ -2136,6 +2147,271 @@ def jax_cond_wrap(): register_primitive(FusedRingAttnBwdPrimitive) +class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Striped Ring Attention Forward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def fwd_impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + 
q_segment_pos, + kv_segment_pos, + ): + if q_segment_ids.size == 0 or kv_segment_ids.size == 0: + raise ValueError("THD + ring attn only supports passing seqment_ids/pos") + + _not_used = jnp.zeros(0, dtype=v.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. + kv = helper.stack_kv(k, v) + if not config.qkv_layout.is_qkvpacked(): + subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked()) + else: + subblock_config = config + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + batch, q_max_seqlen, head, _ = q.shape + output = jnp.zeros(q.shape).astype(jnp.float32) + softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32) + + # RNG shape should be the shared shape. This is unused for ring attention as we do not + # support dropout currently. + rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) + + def scan_kv_block(idx, carry): + kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry + + # TODO(rewang): To check whether we need special handle for the last idx + # Send KV block to next step so we can overlap compute. 
+ kv_next = helper.permute_kv(kv, cp_perm) + kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm) + kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) + + output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + subblock_config, + ) + + # TODO(rewang): THD softmax_aux layout is acutally [B, S, H] + softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1)) + + def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step): + # No correction done here but we cast outputs to float32 and perform reduction + # in full precision. + return output_per_step.astype(jnp.float32), softmax_aux_per_step + + def correction(output, softmax_aux, output_per_step, softmax_aux_per_step): + new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * ( + output - output_per_step + ) + new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step) + return new_out, new_aux + + # first step there is no correction we get initial output and stats + output, softmax_aux = lax.cond( + idx == 0, + skip_correction, + correction, + output, + softmax_aux, + output_per_step, + softmax_aux_per_step, + ) + + return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux) + + carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux) + if helper.use_scanloop(): + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for i in range(0, cp_size): + carry = scan_kv_block(i, carry) + (_, _, _, output, softmax_aux) = carry + + softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1)) + + return output.astype(q.dtype), softmax_aux, rng_state + + return mesh, fwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnStripedFwdPrimitive) + + +class 
FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Striped Ring Attention Backward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + arg_shardings = tuple(arg.sharding for arg in arg_infos) + # dq, dk, dv, dbias sharding = q, k, v, bias sharding + out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + def bwd_impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + ): + + if q_segment_ids.size == 0 or kv_segment_ids.size == 0: + raise ValueError("THD + ring attn only supports passing seqment_ids/pos") + + _not_used = jnp.zeros(0, dtype=output.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. + kv = helper.stack_kv(k, v) + if not config.qkv_layout.is_qkvpacked(): + subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked()) + else: + subblock_config = config + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + dq = jnp.zeros_like(q) + dkv = jnp.zeros_like(kv) + dbias = jnp.zeros_like(bias) + + def scan_kv_block(_idx, carry): + kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry + + # Start communication that feeds the next iteration. + # We further combine the tensors to improve overlap. 
+ kv_dkv = jnp.stack([kv, dkv]) + kv_dkv = helper.permute_kv(kv_dkv, cp_perm) + kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm) + kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) + + def compute(): + dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q, + kv, + _not_used, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + config=subblock_config, + ) + return dq_per_step, dkv_per_step, dbias_per_step + + dq_per_step, dkv_per_step, dbias_per_step = compute() + + kv_next, dkv = jnp.unstack(kv_dkv) + dq += dq_per_step + dkv += dkv_per_step + if config.attn_bias_type is not AttnBiasType.NO_BIAS: + dbias = dbias + dbias_per_step + + return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias) + + carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias) + if helper.use_scanloop(): + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for idx in range(cp_size): + carry = scan_kv_block(idx, carry) + (_, _, _, dq, dkv, dbias) = carry + + # Final permute to put gradients back to their final resting place. 
+ dkv = helper.permute_kv(dkv, cp_perm) + + global_dbias = dbias + if config.attn_bias_type is not AttnBiasType.NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) + + dk, dv = helper.unstack_kv(dkv) + return dq, dk, dv, global_dbias + + return mesh, bwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnStripedBwdPrimitive) + + def _maybe_context_parallel_axis(cp_axis: str): if not cp_axis: gmr = global_mesh_resource() @@ -2151,9 +2427,9 @@ def fused_attn_fwd( bias: Optional[jnp.ndarray], sequence_descriptor: SequenceDescriptor, seed: Optional[jnp.ndarray], - attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, - qkv_layout: NVTE_QKV_Layout, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, scaling_factor: float, dropout_probability: float, is_training: bool, @@ -2184,9 +2460,9 @@ def fused_attn_fwd( kv_seq_offsets (jnp.ndarray): The offsets in the sequence dim for the query, with shape [batch + 1,]. seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. 
@@ -2205,22 +2481,23 @@ def fused_attn_fwd( # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) - match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - assert ( - len(qkv) == 2 - ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - assert ( - len(qkv) == 3 - ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = qkv - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used, _not_used] + elif qkv_layout.is_kvpacked(): + assert ( + len(qkv) == 2 + ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used] + elif qkv_layout.is_separate(): + assert ( + len(qkv) == 3 + ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = qkv + else: + raise ValueError(f"Unknown {qkv_layout=}") + + if attn_bias_type == AttnBiasType.NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) @@ -2242,7 +2519,11 @@ def fused_attn_fwd( case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive case CPStrategy.RING: - primitive = FusedRingAttnFwdPrimitive.outer_primitive + # We must use stripe attention for THD-RING + if qkv_layout.is_thd(): + primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive + else: + primitive = FusedRingAttnFwdPrimitive.outer_primitive seq_desc_flatten, _ = 
jax.tree.flatten(sequence_descriptor) return primitive.bind( @@ -2262,9 +2543,9 @@ def fused_attn_bwd( output: jnp.ndarray, doutput: jnp.ndarray, sequence_descriptor: SequenceDescriptor, - attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, - qkv_layout: NVTE_QKV_Layout, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, scaling_factor: float, dropout_probability: float, is_training: bool, @@ -2296,9 +2577,9 @@ def fused_attn_bwd( The offsets in the sequence dim for the query, with shape [batch + 1,]. kv_seq_offsets (jnp.ndarray): The offsets in the sequence dim for the query, with shape [batch + 1,]. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. 
@@ -2319,22 +2600,23 @@ def fused_attn_bwd( # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) - match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - assert ( - len(qkv) == 2 - ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - assert ( - len(qkv) == 3 - ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = qkv - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used, _not_used] + elif qkv_layout.is_kvpacked(): + assert ( + len(qkv) == 2 + ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used] + elif qkv_layout.is_separate(): + assert ( + len(qkv) == 3 + ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = qkv + else: + raise ValueError(f"Unknown {qkv_layout=}") + + if attn_bias_type == AttnBiasType.NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) @@ -2356,10 +2638,12 @@ def fused_attn_bwd( case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive case CPStrategy.RING: - primitive = FusedRingAttnBwdPrimitive.outer_primitive + if qkv_layout.is_thd(): + primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive + else: + primitive = FusedRingAttnBwdPrimitive.outer_primitive seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) - *qkv_grads, 
bias_grad = primitive.bind( *qkv_for_primitive, bias, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 7447cd1871..bc8c2c9aeb 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -229,6 +229,10 @@ static void FusedAttnForwardImpl( if (is_ragged) { auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); + + // Memset to 0xF0 for filling large negative numbers + auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads; + cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream); } /* Output tensors */ From a3fdb288a8aceb5fc39914d8e4053ea55c8cb033 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sun, 2 Mar 2025 20:01:25 -0800 Subject: [PATCH 152/239] WIP: enable/disable certain cases for fused attn Signed-off-by: Charlene Yang --- .../common/fused_attn/fused_attn.cpp | 5 +- transformer_engine/pytorch/attention.py | 96 +------------------ 2 files changed, 6 insertions(+), 95 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6f9e5f4eb3..3d54354815 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -279,7 +279,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90600 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0)) && + dropout == 0.0) || + (cudnn_runtime_version >= 90701)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || @@ -294,7 +295,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (q_format == 
NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && - cudnn_runtime_version >= 90700)) && + cudnn_runtime_version >= 90701)) && // sliding window // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3cf3945adb..bb4d29e11c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -528,7 +528,6 @@ def get_attention_backend( # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 - # | FP8 | non-paged | sm89+ | bshd,sbhd,thd | # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 # Flash v3 | FP16/BF16 | non-paged/paged | sm80 | bshd,sbhd,thd | >= 1 # | FP8 | non-paged/paged | sm80 | thd | >= 1 @@ -544,8 +543,9 @@ def get_attention_backend( if _flash_attn_3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") use_flash_attention_3 = False - if use_fused_attention and inference_params.is_paged: - logger.debug("Disabling FusedAttention for FP8 paged attention") + if use_fused_attention: + # TODO(cyang): enable fused attn for FP8 non-paged + logger.debug("Disabling FusedAttention for FP8 KV caching") use_fused_attention = False if use_unfused_attention: logger.debug("Disabling UnfusedAttention for FP8 attention") @@ -6939,35 +6939,6 @@ def forward( ) else: with self.attention_dropout_ctx(): - print( - f"{max_seqlen_q=}", - f"{max_seqlen_kv=}", - f"{cu_seqlens_q=}", - f"{cu_seqlens_kv=}", - f"{cu_seqlens_q_padded=}", - f"{cu_seqlens_kv_padded=}", - f"{page_table_k=}", - f"{page_table_v=}", - f"{query_layer.shape}", - 
f"{key_layer.shape}", - f"{value_layer.shape}", - f"{qkv_dtype=}", - f"{core_attention_bias=}", - f"{self.softmax_scale=}", - f"{self.attention_dropout if self.training else 0.0=}", - f"{fast_zero_fill=}", - f"{qkv_layout=}", - f"{core_attention_bias_type=}", - f"{attn_mask_type=}", - f"{window_size=}", - f"{None=}", # rng_gen - f"{fused_attention_backend=}", - f"{use_FAv2_bwd=}", - f"{fp8=}", - f"{fp8_meta=}", - f"{quantizers=}", - f"{self.deterministic=}", - ) output = FusedAttnFunc.apply( self.training, max_seqlen_q, @@ -7704,15 +7675,6 @@ def forward( for x in [query_layer, key_layer, value_layer] ] - if query_layer.shape[0] == 2: - print("bbbbbbbbbbbbbbbbbb") - print("q", query_layer[0, 0, 0, :4]) - print("k", key_layer[0, 0, 0, :4]) - print("v", value_layer[0, 0, 0, :4]) - print("q", query_layer[1, 0, 0, :4]) - print("k", key_layer[1, 0, 0, :4]) - print("v", value_layer[1, 0, 0, :4]) - print("bbbbbbbbbbbbbbbbbb") ( key_layer, value_layer, @@ -7730,58 +7692,6 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None - if query_layer.shape[0] >= 7: - # print('q', query_layer[0,0,0,:4]) - # print('k', key_layer[0,0,0,:4]) - # print('v', value_layer[0,0,0,:4]) - # print('q', query_layer[1,6,0,:4]) - # print('k', key_layer[1,6,0,:4]) - # print('v', value_layer[1,6,0,:4]) - # print('xxxxxxx') - # print('q', query_layer[5,28,0,:4]) - # print('k', key_layer[5,28,0,:4]) - # print('v', value_layer[5,28,0,:4]) - # print('q', query_layer[6,15,0,:4]) - # print('k', key_layer[6,15,0,:4]) - # print('v', value_layer[6,15,0,:4]) - print("xxxxxxx") - # print('q', query_layer[5,26,0,:4]) - # print('k', key_layer[5,26,0,:4]) - # print('v', value_layer[5,26,0,:4]) - # print('q', query_layer[6,13,0,:4]) - # print('k', key_layer[6,13,0,:4]) - # print('v', value_layer[6,13,0,:4]) - # torch.save(query_layer, 'full_q.pt') - # torch.save(key_layer, 'full_k.pt') - # torch.save(value_layer, 'full_v.pt') - print("q", query_layer[5, 35:37, 0, :4]) - print("k", key_layer[5, 
35:37, 0, :4]) - print("v", value_layer[5, 35:37, 0, :4]) - print("q", query_layer[6, 22:24, 0, :4]) - print("k", key_layer[6, 22:24, 0, :4]) - print("v", value_layer[6, 22:24, 0, :4]) - if query_layer.shape[0] == 2: - # torch.save(query_layer, 'partial_q.pt') - # torch.save(key_layer, 'partial_k.pt') - # torch.save(value_layer, 'partial_v.pt') - print("q", query_layer[0, 0, 0, :4]) - print("k", key_layer[0, 36, 0, :4]) - print("v", value_layer[0, 36, 0, :4]) - print("q", query_layer[1, 0, 0, :4]) - print("k", key_layer[1, 23, 0, :4]) - print("v", value_layer[1, 23, 0, :4]) - # print('q', query_layer[0,0,0,:4]) - # print('k', key_layer[0,26,0,:4]) - # print('v', value_layer[0,26,0,:4]) - # print('q', query_layer[1,0,0,:4]) - # print('k', key_layer[1,13,0,:4]) - # print('v', value_layer[1,13,0,:4]) - # print('q', query_layer[0,0,0,:4]) - # print('k', key_layer[0,35:37,0,:4]) - # print('v', value_layer[0,35:37,0,:4]) - # print('q', query_layer[1,0,0,:4]) - # print('k', key_layer[1,22:24,0,:4]) - # print('v', value_layer[1,22:24,0,:4]) # get accurate qkv_layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( From fc5a7e9c74b41aed50476f5216fe364cf19391cc Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sun, 2 Mar 2025 20:40:25 -0800 Subject: [PATCH 153/239] WIP: small fixes for lint and cg Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 36 ++++++++-------- transformer_engine/pytorch/attention.py | 46 +++++++++++---------- transformer_engine/pytorch/inference.py | 8 +--- 3 files changed, 43 insertions(+), 47 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index e9f9b76245..5a9cb7519c 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -384,7 +384,7 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("model", model_configs_infer.keys()) 
@pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) @pytest.mark.parametrize("is_cuda_graph", [False, True]) @pytest.mark.parametrize("is_fp8", [False, True]) @@ -392,7 +392,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda logger = logging.getLogger("test_paged_attn") num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 config = model_configs_infer[model] - if backend == "FlashAttention" and _flash_attn_3_is_installed: + if backend == "FlashAttention" and not _flash_attn_3_is_installed: config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 @@ -413,7 +413,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda page_size = None total_num_pages = None if is_paged: - page_size = 256 if backend == "FlashAttention" and _flash_attn_3_is_installed else 16 + page_size = 256 if backend == "FlashAttention" and not _flash_attn_3_is_installed else 16 config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) else: @@ -438,13 +438,11 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda is_paged=is_paged, page_size=page_size, total_num_pages=total_num_pages, - num_heads_q=config.num_heads, - head_dim_q=config.head_dim_qk, max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, ) for layer_number in range(1, num_layers + 1): - inference_params.allocate_memory(layer_number, qkv_format) + inference_params.allocate_memory(layer_number) # figure out supported backends inference_params_qkv_format = 
"bshd" @@ -546,18 +544,18 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): - model = [ - make_graphed_callables( - model[i], - sample_args, - num_warmup_iters=10, - fp8_enabled=is_fp8, - sample_kwargs=sample_kwargs, - fp8_recipe=fp8_recipe, - ) - for i in range(num_layers) - ] + #with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): + model = [ + make_graphed_callables( + model[i], + sample_args, + num_warmup_iters=10, + fp8_enabled=is_fp8, + sample_kwargs=sample_kwargs, + fp8_recipe=fp8_recipe, + ) + for i in range(num_layers) + ] sim.reset() inference_params.reset() @@ -712,6 +710,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention" and _flash_attn_3_is_installed: + if backend == "FlashAttention" and not _flash_attn_3_is_installed: config.max_seqlen_q = config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bb4d29e11c..8126c7244b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -448,7 +448,7 @@ def get_attention_backend( global _flash_attn_version_required, _flash_attn_max_version # , _use_flash_attn_3 # get q/kv format - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) @@ -2062,7 +2062,7 @@ def forward( # flash_attn_fwd = _flash_attn_varlen_fwd_v3 # else: # flash_attn_fwd = _flash_attn_fwd_v3 - flash_attn_fwd = _flash_attn_fwd_v3 + flash_attn_fwd = _flash_attn_fwd_v3 
# pylint: disable=possibly-used-before-assignment fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) else: if qkv_format == "thd": @@ -3014,10 +3014,11 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if use_fa_v3: - if ctx.qkv_format == "thd": - flash_attn_bwd = _flash_attn_varlen_bwd_v3 - else: - flash_attn_bwd = _flash_attn_bwd_v3 + #if ctx.qkv_format == "thd": + # flash_attn_bwd = _flash_attn_varlen_bwd_v3 + #else: + # flash_attn_bwd = _flash_attn_bwd_v3 + flash_attn_bwd = _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment fa_backward_kwargs["deterministic"] = ctx.deterministic else: if ctx.qkv_format == "thd": @@ -4079,10 +4080,11 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if use_fa_v3: - if ctx.qkv_format == "thd": - flash_attn_bwd = _flash_attn_varlen_bwd_v3 - else: - flash_attn_bwd = _flash_attn_bwd_v3 + #if ctx.qkv_format == "thd": + # flash_attn_bwd = _flash_attn_varlen_bwd_v3 + #else: + # flash_attn_bwd = _flash_attn_bwd_v3 + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: if ctx.qkv_format == "thd": @@ -4610,10 +4612,11 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if use_fa_v3: - if ctx.qkv_format == "thd": - flash_attn_bwd = _flash_attn_varlen_bwd_v3 - else: - flash_attn_bwd = _flash_attn_bwd_v3 + #if ctx.qkv_format == "thd": + # flash_attn_bwd = _flash_attn_varlen_bwd_v3 + #else: + # flash_attn_bwd = _flash_attn_bwd_v3 + flash_attn_bwd = _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: @@ -5286,7 +5289,7 @@ def forward( ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" 
# get q_format and kv_format for training and inference - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) @@ -5893,7 +5896,7 @@ def forward( ] # get accurate batch_size, max_seqlen and cu_seqlens - batch_size, context_len, total_tokens = None, None, None + batch_size, context_len = None, None if inference_params is None: if qkv_format in ["sbhd", "bshd"]: batch_size = query_layer.shape[0] @@ -5970,7 +5973,7 @@ def forward( else: if qkv_format in ["sbhd_2bshd", "bshd"]: # q is in bshd in both cases from conversion above or the original input - batch_size, context_len, num_heads, head_dim = query_layer.shape + batch_size, context_len = query_layer.shape[:2] cu_seqlens_q = cu_seqlens_q[: batch_size + 1] cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] # convert from bshd to thd_2bshd @@ -6049,14 +6052,14 @@ def forward( # | | bshd/sbhd/thd + padding fa_optional_forward_args_thd = [] if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: - func = flash_attn_func if not use_fa_v3 else flash_attn_func_v3 + func = flash_attn_func if not use_fa_v3 else flash_attn_func_v3 # pylint: disable=possibly-used-before-assignment else: if not use_fa_v3: func = flash_attn_varlen_func elif inference_params is None: - func = flash_attn_varlen_func_v3 + func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment else: - func = flash_attn_with_kvcache_v3 + func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment if not use_fa_v3 or inference_params is None: fa_optional_forward_args_thd.append(cu_seqlens_q) fa_optional_forward_args_thd.append(cu_seqlens_kv) @@ -6153,7 +6156,7 @@ def convert_to_torch_float8(tensor, dtype): causal="causal" in attn_mask_type, 
**fa_3_optional_forward_kwargs, ) - if isinstance(output, List) or isinstance(output, Tuple): + if isinstance(output, (List, Tuple)): output = output[0] except TypeError as e: if _flash_attn_3_0_0_beta: @@ -7612,6 +7615,7 @@ def forward( "bshd", "thd", ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" + batch_size = None if qkv_format in ["sbhd", "bshd"]: assert all( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 8bffb5c2cf..a91573cc13 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -121,10 +121,6 @@ class DotProductAttention: Total number of pages in the KV cache. Required for is_paged = True. page_size: int, default = None Page size of the KV cache. Required for is_paged = True. - num_heads_q: int, default = None - Number of attention heads in queries - head_dim_q: int, default = None - Head size for queries. Required for qkv_format = 'thd'. max_ctx_len: int, default = None Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. qkv_format: str, default = "bshd" @@ -144,8 +140,6 @@ def __init__( is_paged: bool = False, total_num_pages: int = None, page_size: int = None, - num_heads_q: int = None, - head_dim_q: int = None, max_ctx_len: int = None, qkv_format: str = "bshd", cache_manager: KVCacheManager = None, @@ -242,7 +236,7 @@ def __repr__(self) -> str: f"head_dim_v={self.head_dim_v}" ) - def allocate_memory(self, layer_number: int, qkv_format: str): + def allocate_memory(self, layer_number: int): """ Allocate memory for the cache. 
For layer layer_number, - NonPagedKVCacheManager: From e1898e17b1623ba758f542d94389073a58b544fb Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sun, 2 Mar 2025 21:40:25 -0800 Subject: [PATCH 154/239] WIP: minor fixes for attn/infer Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_paged_attn.py | 13 ++-- tests/pytorch/test_numerics.py | 17 +++--- transformer_engine/pytorch/attention.py | 6 +- transformer_engine/pytorch/inference.py | 68 ++++++++++----------- 4 files changed, 47 insertions(+), 57 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 5a9cb7519c..ec486917fd 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -441,8 +441,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda max_ctx_len=config.max_ctx_len, qkv_format=qkv_format, ) - for layer_number in range(1, num_layers + 1): - inference_params.allocate_memory(layer_number) + if module == "DotProductAttention": + for layer_number in range(1, num_layers + 1): + inference_params.allocate_memory(layer_number) # figure out supported backends inference_params_qkv_format = "bshd" @@ -649,11 +650,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): for m in model: incremental_output = m( - *( - incremental_output - if isinstance(incremental_output, List) - else incremental_output - ), + *incremental_output + if isinstance(incremental_output, List) + else incremental_output, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 410f501f63..77bd528e62 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -78,7 +78,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq 
model_configs_inference = { # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), + "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256), } backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] @@ -2029,14 +2029,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", all_boolean) +@pytest.mark.parametrize("use_RoPE", [False]) #all_boolean) @pytest.mark.parametrize("input_format", input_formats_inference) @pytest.mark.parametrize("module", module_inference) @pytest.mark.parametrize("backend", backends_inference) @pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.parametrize("is_cuda_graph", [False, True]) def test_kv_cache_accuracy( - dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged, is_cuda_graph + dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged ): reset_rng_states() @@ -2102,17 +2101,12 @@ def test_kv_cache_accuracy( head_dim_k=head_size, dtype=dtype, is_paged=is_paged, - total_num_pages=4, + total_num_pages=int(B_max*S_max/256), page_size=256, - is_cuda_graph=is_cuda_graph, - num_heads_q=H, - head_dim_q=head_size, ) rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") - inference_params.step_dict = OrderedDict(zip(list(range(B)), [1] * B)) - input = torch.randn((S, B, D), dtype=dtype, device="cuda") if input_format == "bshd": input = input.transpose(0, 1).contiguous() @@ -2123,7 +2117,10 @@ def test_kv_cache_accuracy( full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) # Incrementaly generate outputs using KV-cache + step_dict = OrderedDict(zip(list(range(B)), [1] * B)) for i 
in range(S): + inference_params.pre_step(step_dict) + if input_format == "sbhd": incremental_input = input[i].view(1, B, D) else: diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8126c7244b..b358dd5ca9 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -8592,7 +8592,7 @@ def forward( inference_params is not None and self.layer_number not in inference_params.cache_manager.cache ): - inference_params.allocate_memory(self.layer_number, self.qkv_format) + inference_params.allocate_memory(self.layer_number) # ====================== # Query, Key, and Value @@ -8762,8 +8762,8 @@ def forward( f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE." ) - # sequence_start = inference_params.get_seqlens_pre_step() - sequence_start = inference_params.seqlens[0] + sequence_start = inference_params.get_seqlens_pre_step() + #sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index a91573cc13..747a9ab78e 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -87,14 +87,14 @@ class MultiHeadAttention: inference_params.allocate_memory(self.layer_number) class DotProductAttention: if inference_params is not None: - q, k_cache, v_cache, qkv_format = inference_params.step( - new_q, new_k, new_v, qkv_format) - output = attention(q, k_cache, v_cache, new_qkv_format) + k_cache, v_cache, new_qkv_format = inference_params.step( + new_k, new_v, qkv_format) + output = attention(new_q, k_cache, v_cache, new_qkv_format) InferenceParams supports cache_qkv_format = "bshd" only, and the step() function may change qkv_format depending on the attention backend. 
- Backend | Before step() | After step() + backend | qkv_format | new_qkv_format ------------------------------------------------------------------------------------ FusedAttention | {bshd, sbhd, thd} | {bshd_2bshd, sbhd_2bshd, thd_2bshd} FlashAttention | {bshd, sbhd, thd} | {bshd, sbhd, thd} @@ -207,8 +207,17 @@ def __init__( self.step_dict = OrderedDict() self.batch_size = 0 - self.cu_seqlens_q = None - self.cu_seqlens_kv = None + self.cu_seqlens_q = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_kv = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + def reset(self): """Reset InferenceParams state""" @@ -248,17 +257,6 @@ def allocate_memory(self, layer_number: int): """ self.cache_manager.allocate_memory(layer_number) - self.cu_seqlens_q = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - self.cu_seqlens_kv = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - def pre_step( self, step_dict: OrderedDict, @@ -288,7 +286,7 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths for current iteration before adding step_dict.values""" - return self.sequences_pre + return torch.Tensor(list(self.sequences_pre.values())).to(dtype=torch.int32, device="cpu") def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -420,8 +418,18 @@ def __init__( # cache tensors, cache[layer_number] = (k_cache, v_cache) self.cache = {} # track sequence indices in the batch in order to re-index k_cache and v_cache - self.batch_indices = None - self.batch_indices_post = None + self.batch_indices = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + # always in [0, ..., b-1] fashion due to reindexing + self.batch_indices_post = torch.range( + 0, + self.max_batch_size - 1, + 
dtype=torch.int32, + device=torch.cuda.current_device(), + ) def allocate_memory(self, layer_number): """Allocate memory for the cache""" @@ -443,19 +451,6 @@ def allocate_memory(self, layer_number): ) self.cache[layer_number] = (k_cache, v_cache) - self.batch_indices = torch.zeros( - self.max_batch_size, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - # always in [0, ..., b-1] fashion due to reindexing - self.batch_indices_post = torch.range( - 0, - self.max_batch_size - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - def pre_step( self, step_dict: OrderedDict, @@ -618,7 +613,9 @@ def __init__( # allocated pages, {seq_id: [page_id,...]} self.allocated_pages = defaultdict(list) # page table, [batch_size, max_pages_per_seq] - self.page_table = None + self.page_table = torch.zeros( + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + ) def reset(self): """Reset cache manager state""" @@ -649,9 +646,6 @@ def allocate_memory(self, layer_number): ) self.cache[layer_number] = (k_cache, v_cache) - self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" - ) def print_cache(self): """Print KV cache status""" From 96d7d79c3efb3432b79f03b5830007e4f3ef171f Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Sun, 2 Mar 2025 22:11:34 -0800 Subject: [PATCH 155/239] WIP: fix CP Signed-off-by: Charlene Yang --- transformer_engine/pytorch/attention.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b358dd5ca9..eda9f184b1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3745,6 +3745,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4237,6 +4238,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4768,6 +4770,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -5606,6 +5609,7 @@ def 
get_qkv_layout( Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. """ + #q, k, v = [x.contiguous() for x in [q, k, v]] check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" if "_2" in qkv_format: From 90dcc6801c3eb3f20aa3702e61b6ae4dec6f8694 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Mon, 3 Mar 2025 09:25:44 -0800 Subject: [PATCH 156/239] WIP: readd page info to FADescriptor_v1 Signed-off-by: Charlene Yang --- .../fused_attn_f16_arbitrary_seqlen.cu | 18 ++++++++--------- .../common/fused_attn/fused_attn_fp8.cu | 20 +++++++++---------- transformer_engine/common/fused_attn/utils.h | 12 +++++------ 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 968fef5cb5..6bf9bffab9 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -102,10 +102,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( try { FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d_qk, d_v, - //num_pages_k, - //num_pages_v, - //page_size_k, - //page_size_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, scaling_factor, is_training, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, true, tensorType, @@ -519,11 +519,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( try { FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d_qk, d_v, - //0, - //0, - //0, - //0, - 1, 1, bias_b, bias_h, scaling_factor, true, dropout_probability, + 0, + 0, + 0, + 0, + 0, 0, bias_b, bias_h, scaling_factor, true, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, tensorType, 
tensorType}; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 9beadd0a2d..72fd91de11 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1673,11 +1673,11 @@ void fused_attn_fp8_fwd_impl_v1( try { FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d, d, - //0, - //0, - //0, - //0, - 1, 1, bias_b, bias_h, scaling_factor, is_training, + 0, + 0, + 0, + 0, + 0, 0, bias_b, bias_h, scaling_factor, is_training, dropout_probability, layout, bias_type, mask_type, 0, 0, true, fwd_tensor_type, fwd_tensor_type}; @@ -1959,11 +1959,11 @@ void fused_attn_fp8_bwd_impl_v1( try { FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d, d, - //0, - //0, - //0, - //0, - 1, 1, bias_b, bias_h, scaling_factor, true, dropout_probability, + 0, + 0, + 0, + 0, + 0, 0, bias_b, bias_h, scaling_factor, true, dropout_probability, layout, bias_type, mask_type, 0, 0, false, fwd_tensor_type, bwd_tensor_type}; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 8734bb3af1..63da03acf3 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -93,10 +93,10 @@ struct FADescriptor_v1 { std::int64_t s_kv; std::int64_t d_qk; std::int64_t d_v; - //std::int64_t num_pages_k; - //std::int64_t num_pages_v; - //std::int64_t page_size_k; - //std::int64_t page_size_v; + std::int64_t num_pages_k; + std::int64_t num_pages_v; + std::int64_t page_size_k; + std::int64_t page_size_v; std::int64_t max_pages_per_seq_k; std::int64_t max_pages_per_seq_v; std::int64_t bias_b; @@ -115,12 +115,12 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, - //num_pages_k, num_pages_v, page_size_k, page_size_v, + num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, 
max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, - //rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, + rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, From 53fe2c964ab57f75fbc4cdf8fbc6f6bf8b1acfa3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:26:21 +0000 Subject: [PATCH 157/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_paged_attn.py | 10 ++-- tests/pytorch/test_numerics.py | 8 ++- .../fused_attn_f16_arbitrary_seqlen.cu | 51 ++++++++++++++++--- .../common/fused_attn/fused_attn_fp8.cu | 50 +++++++++++++++--- transformer_engine/common/fused_attn/utils.h | 21 ++++---- transformer_engine/pytorch/attention.py | 32 +++++++----- transformer_engine/pytorch/inference.py | 2 - 7 files changed, 125 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index ec486917fd..998f56d0e4 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -545,7 +545,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - #with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): + # with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): model = [ make_graphed_callables( model[i], @@ 
-650,9 +650,11 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): for m in model: incremental_output = m( - *incremental_output - if isinstance(incremental_output, List) - else incremental_output, + *( + incremental_output + if isinstance(incremental_output, List) + else incremental_output + ), cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 77bd528e62..89cd793a35 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2029,14 +2029,12 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", [False]) #all_boolean) +@pytest.mark.parametrize("use_RoPE", [False]) # all_boolean) @pytest.mark.parametrize("input_format", input_formats_inference) @pytest.mark.parametrize("module", module_inference) @pytest.mark.parametrize("backend", backends_inference) @pytest.mark.parametrize("is_paged", [False, True]) -def test_kv_cache_accuracy( - dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged -): +def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): reset_rng_states() if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: @@ -2101,7 +2099,7 @@ def test_kv_cache_accuracy( head_dim_k=head_size, dtype=dtype, is_paged=is_paged, - total_num_pages=int(B_max*S_max/256), + total_num_pages=int(B_max * S_max / 256), page_size=256, ) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6bf9bffab9..2ce93f196a 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -101,14 +101,31 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d_qk, d_v, + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, - max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - scaling_factor, is_training, dropout_probability, layout, bias_type, - mask_type, window_size_left, window_size_right, true, tensorType, + max_pages_per_seq_k, + max_pages_per_seq_v, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + window_size_left, + window_size_right, + true, + tensorType, tensorType}; namespace fe = cudnn_frontend; @@ -518,14 +535,32 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d_qk, d_v, + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, 0, 0, 0, 0, - 0, 0, bias_b, bias_h, scaling_factor, true, dropout_probability, - layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, tensorType, tensorType}; + 0, + 0, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, + window_size_left, + window_size_right, + deterministic, + tensorType, + tensorType}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 72fd91de11..eacd8b53b4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1672,14 +1672,32 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); try { - FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d, d, + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d, + d, 0, 0, 0, 0, - 0, 0, bias_b, bias_h, scaling_factor, is_training, - dropout_probability, layout, bias_type, mask_type, 0, 0, true, - fwd_tensor_type, fwd_tensor_type}; + 0, + 0, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + 0, + 0, + true, + fwd_tensor_type, + fwd_tensor_type}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1958,13 +1976,31 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); try { - FADescriptor_v1 descriptor{b, h, hg, s_q, s_kv, d, d, + FADescriptor_v1 descriptor{b, + h, + hg, + s_q, + s_kv, + d, + d, + 0, + 0, + 0, + 0, 0, 0, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, 0, 0, - 0, 0, bias_b, bias_h, 
scaling_factor, true, dropout_probability, - layout, bias_type, mask_type, 0, 0, false, fwd_tensor_type, + false, + fwd_tensor_type, bwd_tensor_type}; namespace fe = cudnn_frontend; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 63da03acf3..30702a875d 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -114,17 +114,16 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, - num_pages_k, num_pages_v, page_size_k, page_size_v, - max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, - dropoutProbability, layout, mask_type, window_size_left, window_size_right, - deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, - rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, - rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, - rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, + window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, + rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, + rhs.window_size_right, 
rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, + rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index eda9f184b1..7d45d485eb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2062,7 +2062,9 @@ def forward( # flash_attn_fwd = _flash_attn_varlen_fwd_v3 # else: # flash_attn_fwd = _flash_attn_fwd_v3 - flash_attn_fwd = _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment + flash_attn_fwd = ( + _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment + ) fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) else: if qkv_format == "thd": @@ -3014,11 +3016,13 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if use_fa_v3: - #if ctx.qkv_format == "thd": + # if ctx.qkv_format == "thd": # flash_attn_bwd = _flash_attn_varlen_bwd_v3 - #else: + # else: # flash_attn_bwd = _flash_attn_bwd_v3 - flash_attn_bwd = _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) fa_backward_kwargs["deterministic"] = ctx.deterministic else: if ctx.qkv_format == "thd": @@ -4081,9 +4085,9 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if use_fa_v3: - #if ctx.qkv_format == "thd": + # if ctx.qkv_format == "thd": # flash_attn_bwd = _flash_attn_varlen_bwd_v3 - #else: + # else: # flash_attn_bwd = _flash_attn_bwd_v3 flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic @@ -4614,11 +4618,13 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if use_fa_v3: - #if ctx.qkv_format == "thd": + # if ctx.qkv_format == "thd": # flash_attn_bwd = _flash_attn_varlen_bwd_v3 - #else: + # else: # 
flash_attn_bwd = _flash_attn_bwd_v3 - flash_attn_bwd = _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: @@ -5609,7 +5615,7 @@ def get_qkv_layout( Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. """ - #q, k, v = [x.contiguous() for x in [q, k, v]] + # q, k, v = [x.contiguous() for x in [q, k, v]] check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" if "_2" in qkv_format: @@ -6056,7 +6062,9 @@ def forward( # | | bshd/sbhd/thd + padding fa_optional_forward_args_thd = [] if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: - func = flash_attn_func if not use_fa_v3 else flash_attn_func_v3 # pylint: disable=possibly-used-before-assignment + func = ( + flash_attn_func if not use_fa_v3 else flash_attn_func_v3 + ) # pylint: disable=possibly-used-before-assignment else: if not use_fa_v3: func = flash_attn_varlen_func @@ -8767,7 +8775,7 @@ def forward( ) sequence_start = inference_params.get_seqlens_pre_step() - #sequence_start = inference_params.seqlens[0] + # sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] 
diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 747a9ab78e..d5b3035f5c 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -218,7 +218,6 @@ def __init__( device=torch.cuda.current_device(), ) - def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() @@ -646,7 +645,6 @@ def allocate_memory(self, layer_number): ) self.cache[layer_number] = (k_cache, v_cache) - def print_cache(self): """Print KV cache status""" used_pages = [self.get_page_count(seq) for seq in self.sequences] From f250dce80ac0c06a511375c4b3b3e95ddd98d2bc Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Mon, 3 Mar 2025 09:41:39 -0800 Subject: [PATCH 158/239] minor tweak to test_numerics.py Signed-off-by: Charlene Yang --- tests/pytorch/test_numerics.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 89cd793a35..e56bb8868c 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2127,9 +2127,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - seqlens_kv = (i + 1) * torch.ones(B, dtype=torch.int32, device="cuda") - cu_seqlens_kv = torch.zeros(B + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) + cu_seqlens_kv = cu_seqlens_q.clone() mask_type = "padding" kwargs = {} From e5c0e402b0bd8e91b714a2bf6e5393840b453919 Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Mon, 3 Mar 2025 11:40:44 -0800 Subject: [PATCH 159/239] fix 9.5/9.7 sq/skv + mask logic Signed-off-by: Charlene Yang --- .../common/fused_attn/fused_attn.cpp | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git 
a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3d54354815..b5f2a4386a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -272,15 +272,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0) || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && + max_seqlen_q <= max_seqlen_kv)) && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) (cudnn_runtime_version >= 90600 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - (cudnn_runtime_version >= 90701)) && + // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} + // for any q_format/kv_format, and paged/non-paged + (cudnn_runtime_version >= 90700 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + ((attn_mask_type == 
NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + max_seqlen_q <= max_seqlen_kv)))) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || @@ -295,7 +308,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && - cudnn_runtime_version >= 90701)) && + cudnn_runtime_version >= 90700)) && // sliding window // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && From fc1b91c2733e52335dc19622af7153c2146286eb Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Mon, 3 Mar 2025 14:39:55 -0800 Subject: [PATCH 160/239] Launch GEMM on compute_stream which has low priority. 
(#1522) Signed-off-by: Vasudevan Rengasamy --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d988de6f66..3dd5f7228b 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -262,6 +262,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); // Communication: AG and RS int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size @@ -288,14 +289,17 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te assert(pre_gelu_out.numel() == 0); // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch if (_comm_launch_event) - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, - stream_main); + _stream_compute[0]); _ub_comm->sms = ori_sms; NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + } // CommOverlapBase::bulk_overlap /* From bca1f58979dc23bbaa563725a829efa531eb88af 
Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Mon, 3 Mar 2025 15:54:11 -0800 Subject: [PATCH 161/239] clean up Signed-off-by: Charlene Yang --- tests/pytorch/fused_attn/test_fused_attn.py | 4 +- tests/pytorch/fused_attn/test_paged_attn.py | 100 +++-- transformer_engine/pytorch/attention.py | 366 ++++++++---------- .../pytorch/csrc/extensions/attention.cu | 22 +- transformer_engine/pytorch/csrc/kv_cache.cuh | 9 +- transformer_engine/pytorch/inference.py | 129 +++--- 6 files changed, 276 insertions(+), 354 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 1ae3e3bd7f..8f20b0d4cf 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -139,9 +139,9 @@ def _get_attention_backends( pad_between_seqs: bool = False, context_parallel: bool = False, deterministic: bool = False, - is_training: bool = True, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + is_training: bool = True, inference_params: Optional[InferenceParams] = None, ) -> Tuple[List, List]: """Check if what attention backends support a model configuration""" @@ -196,9 +196,9 @@ def test(): attention_dropout=config.dropout_p, context_parallel=context_parallel, deterministic=deterministic, - is_training=is_training, fp8=fp8, fp8_meta=fp8_meta, + is_training=is_training, inference_params=inference_params, ) _, _, flash_attention_backend, fused_attention_backend, _, available_backends = ( diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 998f56d0e4..a9c4736918 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -53,7 +53,7 @@ "infer_0": ModelConfig( 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 ), - # "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6), + # "infer_1": ModelConfig(2, 
16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -92,7 +92,6 @@ def __init__( self.context_lens = torch.randint( 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" ) - # self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate gen lengths in Exponential distribution gen_dist = Exponential(1 / self.max_gen_len) @@ -101,7 +100,6 @@ def __init__( dtype=torch.int32, device="cpu" ) self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu") - # self.gen_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu") # simulate arrival times in Poisson distribution if poisson_rate is None: @@ -111,7 +109,6 @@ def __init__( self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to( dtype=torch.int32, device="cpu" ) - # self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu") self.last_arrival = self.arrival_times.max().item() # initialize tensors @@ -170,6 +167,7 @@ def add_new_seqs(self, new_seq_ids): self.t_ctx_lens = torch.cat([self.t_ctx_lens, self.context_lens[new_seq_ids]], dim=0) gen_lens = torch.Tensor([0] * len(new_seq_ids)).to(dtype=torch.int32, device="cpu") self.t_gen_lens = torch.cat([self.t_gen_lens, gen_lens], dim=0) + # append new seqs' ctx_lens to step_lens self.step_lens = torch.cat([self.step_lens, self.context_lens[new_seq_ids]], dim=0) @@ -181,6 +179,7 @@ def remove_finished(self): self.t_seq_ids = self.t_seq_ids[~finished] self.t_ctx_lens = self.t_ctx_lens[~finished] self.t_gen_lens = self.t_gen_lens[~finished] + # add ones for unfinished seqs to step_lens self.step_lens = torch.ones([len(self.t_seq_ids)], dtype=torch.int32, device="cpu") @@ -188,6 +187,7 @@ def step(self, dynamic_fill: bool = True): # remove finished seqs if self.t != 0: self.remove_finished() + # get allowed new seqs arrived_seq_ids = torch.where(self.arrival_times == self.t, True, 
False).nonzero().view(-1) queuing_seq_ids = torch.cat([self.delayed_seq_ids, arrived_seq_ids], dim=0) @@ -202,8 +202,10 @@ def step(self, dynamic_fill: bool = True): else: new_seq_ids = queuing_seq_ids self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32) + # add new seqs to batch self.add_new_seqs(new_seq_ids) + # update batch variables self.t_batch_size = len(self.t_seq_ids) self.t_total_lens = self.t_ctx_lens + self.t_gen_lens @@ -242,26 +244,27 @@ def get_model( if module == "TransformerLayer": hidden_size = config.head_dim_qk * config.num_heads - model = [ - TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=4 * hidden_size, - num_attention_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - hidden_dropout=0.0, - attention_dropout=config.dropout_p, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - kv_channels=config.head_dim_qk, - self_attn_mask_type=attn_mask_type, - params_dtype=dtype, - attn_input_format=qkv_format, - ) - .cuda() - .eval() - for layer_number in range(1, num_layers + 1) - ] + with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): + model = [ + TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4 * hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type=attn_mask_type, + params_dtype=dtype, + attn_input_format=qkv_format, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] if module == "DotProductAttention": with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): model = [ @@ -288,6 +291,7 @@ def generate_args( qkv_format: str = "bshd", mode: str = "full_inputs", ): + # full inputs used as reference if mode == "full_inputs": warmup = False shapes = [] @@ 
-315,6 +319,7 @@ def generate_args( config.head_dim_v, ] ) + # sample args used for cuda graph warmup elif mode == "sample_args": warmup = True shapes = [] @@ -390,13 +395,6 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("is_fp8", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8): logger = logging.getLogger("test_paged_attn") - num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 - config = model_configs_infer[model] - if backend == "FlashAttention" and not _flash_attn_3_is_installed: - config_max_seqlen_q = config.max_seqlen_q - config_max_seqlen_kv = config.max_seqlen_kv - config.max_seqlen_q = 256 - config.max_seqlen_kv = 256 fp8_recipe = recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.HYBRID, @@ -408,6 +406,15 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda fp8_meta = {} fp8_meta["recipe"] = fp8_recipe + config = model_configs_infer[model] + num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 + # flash-attn v2 requires page_size >= 256 + if backend == "FlashAttention" and not _flash_attn_3_is_installed: + config_max_seqlen_q = config.max_seqlen_q + config_max_seqlen_kv = config.max_seqlen_kv + config.max_seqlen_q = 256 + config.max_seqlen_kv = 256 + # create a real-life simulation max_batch_size = config.batch_size page_size = None @@ -545,7 +552,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sample_kwargs["max_seqlen_q"] = config.max_ctx_len sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv - # with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): model = [ make_graphed_callables( model[i], @@ -640,8 +646,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, 
device="cuda") cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) - cu_seqlens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0) + cu_seqlens_kv = cu_seqlens_q.clone() step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist())) inference_params.pre_step(step_dict) if inference_params.is_paged: @@ -650,11 +655,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): for m in model: incremental_output = m( - *( - incremental_output - if isinstance(incremental_output, List) - else incremental_output - ), + *incremental_output, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, inference_params=inference_params, @@ -669,44 +670,31 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda for i, seq in enumerate(sim.t_seq_ids): token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": - print(i, seq, sim.t_total_lens, sim.step_lens, token_index) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(incremental_output[i, token_index, :4]) torch.testing.assert_close( - # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - # incremental_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], - incremental_output[i, token_index, :], + incremental_output[i, sim.step_lens[i] - 1, :], atol=tol, rtol=tol, ) if qkv_format == "sbhd": - print(i, seq, sim.t_total_lens, sim.step_lens, token_index) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(incremental_output[token_index, i, :4]) torch.testing.assert_close( - # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - # incremental_output[:sim.step_lens[i] - 1, i, :], full_output[seq, sim.t_total_lens[i] - 1, :], - incremental_output[token_index, i, :], + 
incremental_output[sim.step_lens[i] - 1, i, :], atol=tol, rtol=tol, ) if qkv_format == "thd": - print("i ", i, seq, cu_seqlens_q) - print(full_output[seq, sim.t_total_lens[i] - 1, :4]) - print(incremental_output[cu_seqlens_q[i + 1] - 1, :4]) torch.testing.assert_close( - # full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :], - # incremental_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :], full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[cu_seqlens_q[i + 1] - 1, :], atol=tol, rtol=tol, ) + sim.t += 1 sim.t_gen_lens = sim.t_gen_lens + 1 + # last value in complete_times should be equal to sim.t sim.serving_times = sim.arrival_times + sim.request_delays sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7d45d485eb..52076c9a90 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -194,7 +194,6 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False -# _use_flash_attn_3 = False _flash_attn_3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install @@ -220,16 +219,8 @@ def _get_supported_versions(version_min, version_max): from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 - # from flash_attn_3.flash_attn_interface import ( - # _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, - # ) - # from flash_attn_3.flash_attn_interface import ( - # _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, - # ) - _flash_attn_3_is_installed = True _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < 
_flash_attn_3_version < PkgVersion("3.0.0") - # _use_flash_attn_3 = True _attention_backends = { "attention_params": None, @@ -445,9 +436,7 @@ def get_attention_backend( # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. - global _flash_attn_version_required, _flash_attn_max_version # , _use_flash_attn_3 - - # get q/kv format + global _flash_attn_version_required, _flash_attn_max_version qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables @@ -457,10 +446,10 @@ def get_attention_backend( flash_attention_backend = None use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if not use_flash_attention_2 and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0") - if not use_flash_attention_3 and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0") + if (use_flash_attention_2 and _flash_attn_is_installed) or ( + use_flash_attention_3 and _flash_attn_is_installed + ): + logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") if not use_fused_attention: logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") if not use_unfused_attention: @@ -529,8 +518,8 @@ def get_attention_backend( # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 - # Flash v3 | FP16/BF16 | non-paged/paged | sm80 | bshd,sbhd,thd | >= 1 - # | FP8 | non-paged/paged | sm80 | thd | >= 1 + # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 + # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | 
non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: if context_parallel: @@ -544,41 +533,44 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") use_flash_attention_3 = False if use_fused_attention: - # TODO(cyang): enable fused attn for FP8 non-paged logger.debug("Disabling FusedAttention for FP8 KV caching") use_fused_attention = False - if use_unfused_attention: - logger.debug("Disabling UnfusedAttention for FP8 attention") - use_unfused_attention = False else: - if use_fused_attention and pad_between_seqs: - use_fused_attention = False - logger.debug("Disabling FusedAttention for pad_between_seqs = True and KV caching") if q_format == "thd" and pad_between_seqs: - if use_flash_attention_2 and _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention 2 for pad_between_seqs = True and KV caching" - ) - if use_flash_attention_3 and _flash_attn_3_is_installed: - logger.debug( - "Disabling FlashAttention 3 for pad_between_seqs = True and KV caching" - ) - use_flash_attention = False - if inference_params.is_paged: - if use_flash_attention_2 and not _flash_attn_2_5_plus: logger.debug( - "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" + "Disabling all backends for pad_between_seqs = True and KV caching" ) + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if inference_params.is_paged: + if use_flash_attention_2 and inference_params.page_size < 256: + if _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 for page size < 256") use_flash_attention_2 = False + if use_flash_attention_2: + if not _flash_attn_is_installed: + _flash_attn_version_required = PkgVersion("2.5") + elif not _flash_attn_2_5_plus: + logger.debug( + "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" + ) + use_flash_attention_2 = False # Filter: Head dimension if head_dim_qk != head_dim_v: - if 
use_flash_attention_2 and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention 2 as it does not support MLA.") - use_flash_attention_2 = False - if use_flash_attention_3 and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 as it does not support MLA.") - use_flash_attention_3 = False + if (use_flash_attention_2 and _flash_attn_is_installed) or ( + use_flash_attention_3 and _flash_attn_is_installed + ): + logger.debug("Disabling FlashAttention as it does not support MLA.") + use_flash_attention = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 @@ -600,15 +592,8 @@ def get_attention_backend( use_flash_attention_2 = False if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): if _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 due to head_dim > 128") + logger.debug("Disabling FlashAttention 3 for head_dim > 128") use_flash_attention_3 = False - qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") - if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": - logger.debug( - "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", - qkv_layout, - ) - use_fused_attention = False # Filter: QKV layout if qkv_format == "thd": @@ -648,36 +633,28 @@ def get_attention_backend( ) use_unfused_attention = False if context_parallel and (use_flash_attention_2 or use_flash_attention_3): - if fp8 and fp8_meta["recipe"].fp8_dpa: - if _flash_attn_is_installed or _flash_attn_3_is_installed: + if _flash_attn_is_installed or _flash_attn_3_is_installed: + if fp8 and fp8_meta["recipe"].fp8_dpa: logger.debug( "Disabling FlashAttention as 
it does not support context parallelism with FP8" ) - use_flash_attention = False - if "bottom_right" in attn_mask_type: - if _flash_attn_is_installed or _flash_attn_3_is_installed: + if "bottom_right" in attn_mask_type: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal_bottom_right masking" ) - use_flash_attention = False - elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - if _flash_attn_is_installed or _flash_attn_3_is_installed: + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal masking for cross-attention" ) - use_flash_attention = False - elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - if _flash_attn_is_installed or _flash_attn_3_is_installed: + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: logger.debug( "Disabling FlashAttention as it does not support context parallelism with bias" " type of %s", core_attention_bias_type, ) - use_flash_attention = False - elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - if _flash_attn_is_installed or _flash_attn_3_is_installed: + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " attention bias for THD format" @@ -735,36 +712,25 @@ def get_attention_backend( # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": - if use_flash_attention_2 and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention 2 for arbitrary mask") - use_flash_attention_2 = False - if use_flash_attention_3 and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 for arbitrary mask") - use_flash_attention_3 = False + if (use_flash_attention_2 and _flash_attn_is_installed) or ( + 
use_flash_attention_3 and _flash_attn_is_installed + ): + logger.debug("Disabling FlashAttention for arbitrary mask") + use_flash_attention = False if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False if ( - use_flash_attention_3 + (use_flash_attention_2 or use_flash_attention_3) and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): logger.warning( - "Disabling FlashAttention 3 as it only supports bottom-right-diagonal causal mask." + "Disabling FlashAttention as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1 (our minimum supported version). See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) - use_flash_attention_3 = False - if ( - use_flash_attention_2 - and attn_mask_type in ["causal", "padding_causal"] - and max_seqlen_q != max_seqlen_kv - ): - if _flash_attn_2_1_plus: - logger.warning( - "Disabling FlashAttention 2 as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. 
See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention_2 = False + use_flash_attention = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -829,12 +795,9 @@ def get_attention_backend( core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias_shape is not None ): - if use_flash_attention_2 and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention 2 for pre/post_scale_bias") - use_flash_attention_2 = False - if use_flash_attention_3 and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 for pre/post_scale_bias") - use_flash_attention_3 = False + if (use_flash_attention_2 and _flash_attn_is_installed) or (use_flash_attention_3 and _flash_attn_3_is_installed): + logger.debug("Disabling FlashAttention for pre/post_scale_bias") + use_flash_attention = False fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias_shape = core_attention_bias_shape @@ -963,7 +926,7 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for determinism reasons") use_fused_attention = False - # use_flash_attention may have been used for both FAv2 and FAv3 above + # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 use_flash_attention_3 = use_flash_attention and use_flash_attention_3 @@ -993,16 +956,20 @@ def get_attention_backend( use_flash_attention_3 = False use_flash_attention = use_flash_attention_2 or use_flash_attention_3 available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] - # use FAv3 when both are present if use_flash_attention_2: flash_attention_backend = _flash_attn_version if use_flash_attention_3: flash_attention_backend = _flash_attn_3_version logger.debug( - "Available backends = {FlashAttention=%s, FusedAttention=%s%s," + "Available backends = {FlashAttention=%s%s, FusedAttention=%s%s," " 
UnfusedDotProductAttention=%s}", bool(available_backends[0]), + ( + f" ({str(flash_attention_backend)})" + if flash_attention_backend is not None + else "" + ), bool(available_backends[1]), ( f" (sub-backend {int(fused_attention_backend)})" @@ -1028,7 +995,7 @@ def get_attention_backend( use_unfused_attention = False selected_backend = "NoBackend" if use_flash_attention: - selected_backend = "FlashAttention" + selected_backend = "FlashAttention ({str(flash_attention_backend)})" elif use_fused_attention: selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" elif use_unfused_attention: @@ -1894,7 +1861,7 @@ def forward( cp_global_ranks, cp_stream, quantizers, - use_fa_v3, + use_flash_attn_3, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") @@ -2052,16 +2019,12 @@ def forward( if use_fused_attention: softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) else: - softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or use_fa_v3 + softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or use_flash_attn_3 flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if use_fa_v3: - # if qkv_format == "thd": - # flash_attn_fwd = _flash_attn_varlen_fwd_v3 - # else: - # flash_attn_fwd = _flash_attn_fwd_v3 + if use_flash_attn_3: flash_attn_fwd = ( _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment ) @@ -2073,7 +2036,7 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or use_fa_v3: + if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or use_flash_attn_3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 @@ -2253,12 +2216,12 @@ def forward( if not _flash_attn_2_7_0_plus: 
out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[3] elif i <= rank: if pad_between_seqs_q: @@ -2363,7 +2326,7 @@ def forward( max_seqlen_q, max_seqlen_kv // 2, ] - if use_fa_v3 or ( + if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) @@ -2389,12 +2352,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: @@ -2508,7 +2471,7 @@ def forward( max_seqlen_q // 2, max_seqlen_kv, ] - if use_fa_v3 or ( + if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) @@ -2534,12 +2497,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: @@ -2648,12 +2611,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[3] if i > 0: @@ -2842,7 +2805,7 @@ def forward( 
ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - ctx.use_fa_v3 = use_fa_v3 + ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") return out_ret @@ -2851,7 +2814,7 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") - use_fa_v3 = ctx.use_fa_v3 + use_flash_attn_3 = ctx.use_flash_attn_3 cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a @@ -3015,11 +2978,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if use_fa_v3: - # if ctx.qkv_format == "thd": - # flash_attn_bwd = _flash_attn_varlen_bwd_v3 - # else: - # flash_attn_bwd = _flash_attn_bwd_v3 + if use_flash_attn_3: flash_attn_bwd = ( _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment ) @@ -3168,12 +3127,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, 0) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 - if not use_fa_v3: + if not use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3283,12 +3242,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) if _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not use_fa_v3: + if not use_flash_attn_3: fa_backward_kwargs["rng_state"] = 
rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3400,12 +3359,12 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not use_fa_v3: + if not use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3494,12 +3453,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not use_fa_v3: + if not use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout, @@ -3808,7 +3767,7 @@ def forward( window_size, cp_group, cp_stream, - use_fa_v3, + use_flash_attn_3, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -3834,11 +3793,7 @@ def forward( flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if use_fa_v3: - # if qkv_format == "thd": - # flash_attn_fwd = _flash_attn_varlen_fwd_v3 - # else: - # flash_attn_fwd = _flash_attn_fwd_v3 + if use_flash_attn_3: flash_attn_fwd = _flash_attn_fwd_v3 else: if qkv_format == "thd": @@ -3961,7 +3916,7 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size_per_step[i] 
elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -3977,12 +3932,12 @@ def forward( if not _flash_attn_2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not use_fa_v3: + if not use_flash_attn_3: rng_states[i] = fa_outputs[3] if i > 0: @@ -4027,7 +3982,7 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - ctx.use_fa_v3 = use_fa_v3 + ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") return out @@ -4035,7 +3990,7 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") - use_fa_v3 = ctx.use_fa_v3 + use_flash_attn_3 = ctx.use_flash_attn_3 cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) @@ -4084,11 +4039,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if use_fa_v3: - # if ctx.qkv_format == "thd": - # flash_attn_bwd = _flash_attn_varlen_bwd_v3 - # else: - # flash_attn_bwd = _flash_attn_bwd_v3 + if use_flash_attn_3: flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: @@ -4158,7 +4109,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, max_seqlen_kv, ] - if not use_fa_v3: + if not use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] @@ -4279,7 +4230,7 @@ def forward( cp_group, cp_stream, quantizers, - use_fa_v3, + use_flash_attn_3, ): # pylint: disable=missing-function-docstring 
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -4304,11 +4255,7 @@ def forward( flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if use_fa_v3: - # if qkv_format == "thd": - # flash_attn_fwd = _flash_attn_varlen_fwd_v3 - # else: - # flash_attn_fwd = _flash_attn_fwd_v3 + if use_flash_attn_3: flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = window_size else: @@ -4318,7 +4265,7 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size[0] @@ -4441,10 +4388,10 @@ def forward( ) if not _flash_attn_2_7_0_plus: out, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not use_fa_v3 else None + rng_state = fa_outputs[7] if not use_flash_attn_3 else None else: out, softmax_lse = fa_outputs[0], fa_outputs[1] - rng_state = fa_outputs[3] if not use_fa_v3 else None + rng_state = fa_outputs[3] if not use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) @@ -4527,7 +4474,7 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - ctx.use_fa_v3 = use_fa_v3 + ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") return out_ret @@ -4535,7 +4482,7 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") - use_fa_v3 = ctx.use_fa_v3 + use_flash_attn_3 = ctx.use_flash_attn_3 cp_size = 
get_distributed_world_size(ctx.cp_group) ( @@ -4617,11 +4564,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if use_fa_v3: - # if ctx.qkv_format == "thd": - # flash_attn_bwd = _flash_attn_varlen_bwd_v3 - # else: - # flash_attn_bwd = _flash_attn_bwd_v3 + if use_flash_attn_3: flash_attn_bwd = ( _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment ) @@ -4633,7 +4576,7 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if use_fa_v3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = ctx.window_size elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = ctx.window_size[0] @@ -4709,7 +4652,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if not use_fa_v3: + if not use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( dout, @@ -4807,7 +4750,7 @@ def attn_forward_func_with_cp( fp8=False, fp8_meta=None, quantizers=None, - use_fa_v3=False, + use_flash_attn_3=False, ) -> torch.Tensor: """ Attention implementation with context parallelism. 
@@ -4875,15 +4818,15 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, use_fa_v3] + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, use_flash_attn_3] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_fa_v3] + args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_fa_v3] + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -5299,7 +5242,6 @@ def forward( # get q_format and kv_format for training and inference qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) - if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) @@ -5590,16 +5532,27 @@ def get_qkv_layout( Returns ---------- qkv_layout: str - Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five - memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk - of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means - `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v` - are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and - `v = kv[:,:,:,1,:]`. + Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and + `kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that + paged KV caching is in play. A few examples of the layouts are as follows. 
+ + (1) `sb3hd` means `q`, `k`, `v` are created as one chunk of memory and that they are + interleaved in the `2`nd dimension. (2) `sbhd_sbh2d` means `q` and `kv` are created in + two chunks and that `q` itself is contiguous and `k`, `v` are interleaved with each other + in the `3`rd dimension, `k = kv[:,:,:,0,:]` and `v = kv[:,:,:,1,:]`. `q_format` and + `kv_format` in this case are still both `sbhd`. (3) `paged_kv_thd_bshd_bshd` means `q` is + created in `thd` and `k`, `v` are in `bshd`. This is likely due to the cache format in + paged KV caching. + Mapping: - `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} - `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} + `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`, `paged_kv_sbhd_sbhd_sbhd`} + `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`, `paged_kv_bshd_bshd_bshd`} `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} + `sbhd_2bshd`: {`sbhd_bshd_bshd`, `paged_kv_sbhd_bshd_bshd`} + `bshd_2sbhd`: {`bshd_sbhd_sbhd`, `paged_kv_bshd_sbhd_sbhd`} + `thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`} + `thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`} + q: torch.Tensor Query tensor. It may be different from input `q` as we try to fit tensors to a supported layout. @@ -5615,7 +5568,6 @@ def get_qkv_layout( Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. """ - # q, k, v = [x.contiguous() for x in [q, k, v]] check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" 
if "_2" in qkv_format: @@ -5905,7 +5857,7 @@ def forward( x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - # get accurate batch_size, max_seqlen and cu_seqlens + # get batch_size, max_seqlen and cu_seqlens batch_size, context_len = None, None if inference_params is None: if qkv_format in ["sbhd", "bshd"]: @@ -5986,7 +5938,8 @@ def forward( batch_size, context_len = query_layer.shape[:2] cu_seqlens_q = cu_seqlens_q[: batch_size + 1] cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] - # convert from bshd to thd_2bshd + # convert from bshd to thd_2bshd for flash_attn_varlen_func/_with_kvcache; + # kernel assumes tensor is contiguous if isinstance(query_layer, Float8Tensor): query_layer._data = tex.convert_bshd_to_thd( query_layer._data, @@ -6003,9 +5956,9 @@ def forward( batch_size * context_len, ) - use_fa_v3 = False + use_flash_attn_3 = False if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): - use_fa_v3 = True + use_flash_attn_3 = True if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -6035,7 +5988,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, - use_fa_v3=use_fa_v3, + use_flash_attn_3=use_flash_attn_3, ) else: @@ -6052,32 +6005,32 @@ def forward( # ---------------------------------------------------------------------- # FA v2 | flash_attn_func | bshd/sbhd + not padding # | flash_attn_varlen_func | bshd/sbhd + padding - # | | thd + padding + not pad_between_seqs + # | | thd + padding # | | KV cache (not-paged/paged), i.e. # | | bshd/sbhd/thd + padding # FA v3 | flash_attn_func | bshd/sbhd + not padding # | flash_attn_varlen_func | bshd/sbhd + padding - # | | thd + padding + not pad_between_seqs + # | | thd + padding # | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e. 
# | | bshd/sbhd/thd + padding fa_optional_forward_args_thd = [] if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: func = ( - flash_attn_func if not use_fa_v3 else flash_attn_func_v3 + flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3 ) # pylint: disable=possibly-used-before-assignment else: - if not use_fa_v3: + if not use_flash_attn_3: func = flash_attn_varlen_func elif inference_params is None: func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment else: func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment - if not use_fa_v3 or inference_params is None: + if not use_flash_attn_3 or inference_params is None: fa_optional_forward_args_thd.append(cu_seqlens_q) fa_optional_forward_args_thd.append(cu_seqlens_kv) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) - if not use_fa_v3: + if not use_flash_attn_3: fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: fa_optional_forward_kwargs["window_size"] = window_size @@ -6086,11 +6039,11 @@ def forward( if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic if inference_params is not None: - # use block_table to support thd_2bshd format for non-paged + # use block_table kwarg to support thd_2bshd for non-paged fa_optional_forward_kwargs["block_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged - else inference_params.cache_manager.batch_indices_post.unsqueeze(1)[ + else inference_params.cache_manager.batch_indices_post_step.unsqueeze(1)[ :batch_size ] ) @@ -6114,6 +6067,7 @@ def forward( fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens + # flash_attn_with_kvcache accepts thd_2bshd for non-paged if inference_params.is_paged: fa_3_optional_forward_kwargs["page_table"] = ( 
inference_params.cache_manager.page_table[:batch_size] @@ -6191,7 +6145,7 @@ def convert_to_torch_float8(tensor, dtype): output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) elif qkv_format in ["bshd", "sbhd_2bshd"]: # all KV caching cases use thd_2bshd for calculation - # convert results from thd_2bshd back to bshd + # convert results back to bshd from thd_2bshd if isinstance(query_layer, Float8Tensor): output._data = tex.convert_thd_to_bshd( output._data, @@ -6793,8 +6747,6 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_kv_padded: Optional[torch.Tensor] = None, - page_table_k: Optional[torch.Tensor] = None, - page_table_v: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", @@ -6839,6 +6791,7 @@ def forward( # get q_format and kv_format for training and inference qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + page_table = None if inference_params is None: if qkv_format in ["sbhd", "bshd"]: if qkv_format == "sbhd": @@ -6887,6 +6840,8 @@ def forward( and cu_seqlens_q is not None and cu_seqlens_kv is not None ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" 
+ elif inference_params.is_paged: + page_table = inference_params.cache_manager.page_table if (q_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_q_padded is None: cu_seqlens_q_padded = cu_seqlens_q @@ -6962,8 +6917,8 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - page_table_k, - page_table_v, + page_table, + page_table, query_layer, key_layer, value_layer, @@ -7672,7 +7627,7 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - # retrieve tokens from KV cache in inference + # update KV cache and retrieve saved tokens from cache for inference page_table = None if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -7708,7 +7663,7 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None - # get accurate qkv_layout + # get qkv_layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( qkv_layout, @@ -7840,7 +7795,7 @@ def forward( and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) ) - # gather attention params for get available attention backends + # gather attention params for get_attention_backend attention_params = AttentionParams( qkv_type=type(query_layer), qkv_dtype=query_layer.dtype, @@ -7869,7 +7824,7 @@ def forward( fp8_meta=self.fp8_meta, inference_params=inference_params, ) - global _attention_backends # , _use_flash_attn_3 + global _attention_backends if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -7877,7 +7832,6 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: - # _use_flash_attn_3 = _flash_attn_3_is_installed ( use_flash_attention, flash_attention_backend, @@ -7905,7 +7859,7 @@ def forward( 
fused_attention_backend = _attention_backends["fused_attention_backend"] use_unfused_attention = _attention_backends["use_unfused_attention"] - # if no backend is available, raise exception + # raise exception if no backend is available if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0: raise ValueError("No dot product attention support for the provided inputs!") @@ -7969,8 +7923,6 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - page_table_k=page_table, - page_table_v=page_table, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, @@ -7997,8 +7949,6 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - page_table_k=page_table, - page_table_v=page_table, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, @@ -8770,9 +8720,7 @@ def forward( elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) else: - raise ValueError( - f"qkv_format={self.qkv_format} is not supported for KV caching and RoPE." - ) + raise ValueError(f"qkv_format={self.qkv_format} not supported for KV caching and RoPE.") sequence_start = inference_params.get_seqlens_pre_step() # sequence_start = inference_params.seqlens[0] diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 00221e0bab..da82120f4a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1113,15 +1113,19 @@ at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) /*************************************************************************************************** * KV Cache: Copy new KV tokens to the KV cache - * 1. 
new_k and new_v are in qkv_format, and k_cache and v_cache are in 'bshd' format - * 2. cu_new_lens and cu_cached_lens are in shape, [b + 1], and cu_cached_lens are the lens after current step - * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1], - * max_pages_per_seq = 1. Set is_non_paged = True/False accordingly. - * 4. is_non_paged = True re-indexes the cache based on the page_table, i.e. page_table = - * [[0], [3], [1], [2]] will rearrange the cache to be [[0], [1], [1], [2]]. - * 5. k_cache and v_cache should have the same page_table - * 6. For qkv_format = thd, we assume there is no padding between sequences in new_k and new_v, - * e.g. new_k = [a a a b b c], not new_k = [a a a 0..0 b b 0..0 c 0..0]. + * 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format + * 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens + * in current step + * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and + * max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged. + * Set is_non_paged = True/False to indicate as such. + * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2] + * becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged. + * batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is + * preserved for the next layer in the same iteration. + * 5. Only supports same page_table for k_cache and v_cache + * 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens + * between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0]. 
**************************************************************************************************/ template diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh index bd585618bb..e79690d215 100644 --- a/transformer_engine/pytorch/csrc/kv_cache.cuh +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -11,6 +11,8 @@ namespace fused_attn { template __global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, int b, int max_seq_len, int h, int d) { + // tensor: thd; new_tensor: bshd + // cu_seqlens: [b + 1] for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d; int thd_offset = cu_seqlens[batch_idx] * h * d; @@ -26,6 +28,8 @@ __global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tenso template __global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, int b, int max_seq_len, int h, int d) { + // tensor: bshd; new_tensor: thd + // cu_seqlens: [b + 1] for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]; int num_elts = seqlen * h * d; @@ -69,11 +73,6 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in } } } - // if (blockIdx.x == 0) { - // for (int batch_idx = threadIdx.x; batch_idx < b; batch_idx++) { - // batch_indices[batch_idx] = batch_idx; - // } - // } } template diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index d5b3035f5c..40797631be 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -55,50 +55,44 @@ def step( class InferenceParams: """ - Inference parameters that are passed to the main model in order - to efficiently cache previous tokens and reuse them for the current - inference iteration. 
- - A typical KV caching workflow is as follows.:: - - modules = [TransformerLayer() for _ in range(num_layers)] - model = torch.nn.Sequential(*modules) - inference_params = InferenceParams(max_batch_size, max_seqlen_kv, ...) - for i in range(inference_iterations): - # seq_ids = [0, 2, 3] - # step_lens = [10, 1, 1] - # step_dict = OrderedDict(zip(seq_ids, step_lens)) + KV caching mechanism in inference. The memory allocation of the caches, and the copying of + new tokens to the cache take place at the following locations in TransformerLayer.:: + + class TransformerLayer: + class MultiHeadAttention: + if self.layer_number not in inference_params.cache_manager.cache: + inference_params.allocate_memory(self.layer_number) + class DotProductAttention: + if inference_params is not None: + k_cache, v_cache, new_qkv_format = inference_params.step( + new_k, new_v, qkv_format) + output = attention(new_q, k_cache, v_cache, new_qkv_format) + + allocate_memory() can be called independently if needed. step() takes 'bshd', 'sbhd' and 'thd' + formats and converts new_k and new_v to 'bshd' in both NonPagedKVCacheManager and PagedKVCacheManager. + Since new_q's format is unchanged, the returned new_qkv_format is 'bshd', 'sbhd_2bshd' and 'thd_2bshd', + respectively. A standard workflow for using InferenceParams to cache KV tokens, is as follows.:: + + model = [TransformerLayer() for _ in range(num_layers)] + # initialize InferenceParams, for example, with PagedKVCacheManager + inference_params = InferenceParams(..., is_paged=True) + # inference iterations + for i in range(num_iters): + # get step info, e.g. 
seq_ids = [0, 2, 3], step_lens = [10, 1, 1] + step_dict = OrderedDict(zip(seq_ids, step_lens)) + # update inference_params state inference_params.pre_step(step_dict) output = model( ..., - inference_params=inference_params, attn_mask_type="padding_causal", + cu_seqlens_q=cu_seqlens_new_q, + cu_seqlens_kv=cu_seqlens_new_kv, + inference_params=inference_params, ) - # assume qkv_format = "bshd" - output = output[:,step_dict.values()] - - - The memory allocation and copies of the new KV tokens to KV cache take place - in the following locations.:: - - class TransformerLayer: - class MultiHeadAttention: - if self.layer_number not in inference_params: - inference_params.allocate_memory(self.layer_number) - class DotProductAttention: - if inference_params is not None: - k_cache, v_cache, new_qkv_format = inference_params.step( - new_k, new_v, qkv_format) - output = attention(new_q, k_cache, v_cache, new_qkv_format) - - InferenceParams supports cache_qkv_format = "bshd" only, and the step() function may - change qkv_format depending on the attention backend. 
- - backend | qkv_format | new_qkv_format - ------------------------------------------------------------------------------------ - FusedAttention | {bshd, sbhd, thd} | {bshd_2bshd, sbhd_2bshd, thd_2bshd} - FlashAttention | {bshd, sbhd, thd} | {bshd, sbhd, thd} - UnfusedDotProductAttention | {bshd, sbhd, thd} | {bshd, sbhd, bshd} + # get inference tokens based on qkv_format + # "bshd": output = output[:,step_dict.values()-1] + # "sbhd": output = output[step_dict.values()-1,:] + # "thd" : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1 Parameters @@ -151,9 +145,6 @@ def __init__( self.dtype = dtype self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k self.is_paged = is_paged - _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) - _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) - _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) if not self.is_paged: cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager @@ -167,6 +158,7 @@ def __init__( ) else: assert page_size is not None, "Paged KV cache requires page_size is not None." + self.page_size = page_size assert ( max_seqlen_kv % page_size == 0 ), "Paged KV cache requires max_seqlen_kv % page_size = 0." @@ -174,8 +166,6 @@ def __init__( assert ( total_num_pages == self.max_batch_size * max_pages_per_seq ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." - self.page_size = page_size - self.max_seqlen_kv = max_seqlen_kv self.total_num_pages = total_num_pages cls = cache_manager if cache_manager is not None else PagedKVCacheManager @@ -194,7 +184,6 @@ def __init__( assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" 
self.max_ctx_len = max_ctx_len - # NonPagedKVCacheManager and PagedKVCacheManager only support 'bshd' cache self.cache_qkv_format = "bshd" self.input_qkv_format = qkv_format if self.input_qkv_format == self.cache_qkv_format: @@ -202,9 +191,8 @@ def __init__( else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - self.sequences_pre = OrderedDict() + self.sequences_pre_step = OrderedDict() self.sequences = OrderedDict() - self.step_dict = OrderedDict() self.batch_size = 0 self.cu_seqlens_q = torch.zeros( @@ -261,37 +249,33 @@ def pre_step( step_dict: OrderedDict, ): """Update tracked sequences and prepare for step()""" - self.step_dict = step_dict self.batch_size = len(step_dict) - self.total_tokens = sum(step_dict.values()) self.sequences = self.cache_manager.pre_step(step_dict) - self.sequences_pre = OrderedDict() + # track the pre-step seqlens for the next layer in the model + self.sequences_pre_step = OrderedDict() for k, v in self.sequences.items(): - self.sequences_pre[k] = v - self.step_dict[k] + self.sequences_pre_step[k] = v - step_dict[k] - actual_batch_size = len(step_dict) seqlens_q = list(step_dict.values()) - cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, actual_batch_size + 1)] - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - actual_batch_size) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu")) - seq_lens = list(self.sequences.values()) - cu_seqlens_kv = [0] + [sum(seq_lens[:i]) for i in range(1, actual_batch_size + 1)] + seqlens_kv = list(self.sequences.values()) + cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( - self.max_batch_size - actual_batch_size + self.max_batch_size - 
self.batch_size ) self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu")) def get_seqlens_pre_step(self): - """Get cached sequence lengths for current iteration before adding step_dict.values""" - return torch.Tensor(list(self.sequences_pre.values())).to(dtype=torch.int32, device="cpu") + """Get cached sequence lengths before the stepping""" + return torch.Tensor(list(self.sequences_pre_step.values())).to(dtype=torch.int32, device="cpu") def convert_paged_to_nonpaged(self, layer_number: int): """ - Convert k_cache and v_cache from paged to non-paged format. This is used by the - UnfusedDotProductAttention backend. Both k_cache and v_cache are assumed to be - in 'bshd' format. + Convert k_cache and v_cache from paged to non-paged format. Parameters ---------- @@ -308,7 +292,6 @@ def convert_paged_to_nonpaged(self, layer_number: int): k_cache, v_cache = self.cache_manager.cache[layer_number] page_table = self.cache_manager.page_table batch_size = page_table.shape[0] - actual_batch_size = len(self.step_dict) new_k_cache = rearrange( k_cache[page_table.flatten()], "(b npages) page_size ... -> b (npages page_size) ...", @@ -320,8 +303,8 @@ def convert_paged_to_nonpaged(self, layer_number: int): b=batch_size, ) - new_k_cache = new_k_cache[:actual_batch_size].contiguous() - new_v_cache = new_v_cache[:actual_batch_size].contiguous() + new_k_cache = new_k_cache[:self.batch_size].contiguous() + new_v_cache = new_v_cache[:self.batch_size].contiguous() return new_k_cache, new_v_cache @@ -333,7 +316,7 @@ def step( qkv_format: str, ): """ - Copy the new KV tokens to the KV cache and reshape Q if necessary. + Copy new KV tokens to the cache. Parameters ---------- @@ -353,7 +336,7 @@ def step( v_cache: torch.Tensor Full value tensor containing both previous and current value tokens page_table: torch.Tensor - Page table for paged KV cache, [batch_size, max_pages_per_seq]. 
None for non-paged KV cache + Page table for paged KV cache, [batch_size, max_pages_per_seq]. None for non-paged KV cache. cu_seqlens_q: torch.Tensor Updated cumulative sequence lengths for query, [batch_size + 1] cu_seqlens_kv: torch.Tensor @@ -363,7 +346,7 @@ def step( max_seqlen_kv: int Update maximum sequence length for key and value qkv_format: str - Updated qkv_format, e.g. the input 'thd' format may become 'thd_2bshd' after step() + Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step() """ self.input_qkv_format = qkv_format if self.input_qkv_format == self.cache_qkv_format: @@ -422,8 +405,8 @@ def __init__( dtype=torch.int32, device=torch.cuda.current_device(), ) - # always in [0, ..., b-1] fashion due to reindexing - self.batch_indices_post = torch.range( + # after re-indexing, batch indices are always [0, ..., b-1] + self.batch_indices_post_step = torch.range( 0, self.max_batch_size - 1, dtype=torch.int32, @@ -456,9 +439,9 @@ def pre_step( ): """Update tracked sequences and prepare for step()""" # Track unfinished sequences' indices in the batch, e.g. - # at t-1, seq_ids = [0, 1, 2, 3], and at t, seq_ids = [0, 2, 3], because seq_id 1 finished - # batch_indices = [0, 2, 3, 1] is used in step() to re-index k_cache and v_cache so that - # they are contiguous and match the sequence indexing in q. 
+ # at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished + # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that + # they are contiguous and match the indexing in q prev_batch_size = len(self.sequences) unfinished_seqs = self.sequences.keys() & step_dict.keys() finished_seqs = self.sequences.keys() - unfinished_seqs From 3618785b1ac6d02d0c838838a650a65086715bb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Mar 2025 23:55:05 +0000 Subject: [PATCH 162/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 16 ++++---- transformer_engine/pytorch/attention.py | 40 +++++++++++-------- transformer_engine/pytorch/inference.py | 8 ++-- 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b5f2a4386a..6ede42a881 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -273,27 +273,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - max_seqlen_q <= max_seqlen_kv)) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) (cudnn_runtime_version >= 90600 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && 
max_seqlen_kv % 64 == 0 && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0) || + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} - // for any q_format/kv_format, and paged/non-paged + // for any q_format/kv_format, and paged/non-paged (cudnn_runtime_version >= 90700 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv)))) && + max_seqlen_q <= max_seqlen_kv)))) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 52076c9a90..37160c77e1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -537,9 +537,7 @@ def get_attention_backend( use_fused_attention = False else: if q_format == "thd" and pad_between_seqs: - logger.debug( - "Disabling all backends for pad_between_seqs = True and KV caching" - ) + logger.debug("Disabling all backends for pad_between_seqs = True and KV caching") use_flash_attention = False use_fused_attention = False use_unfused_attention = False @@ -795,7 +793,9 @@ def get_attention_backend( core_attention_bias_type not in ["no_bias", "alibi"] or 
core_attention_bias_shape is not None ): - if (use_flash_attention_2 and _flash_attn_is_installed) or (use_flash_attention_3 and _flash_attn_3_is_installed): + if (use_flash_attention_2 and _flash_attn_is_installed) or ( + use_flash_attention_3 and _flash_attn_3_is_installed + ): logger.debug("Disabling FlashAttention for pre/post_scale_bias") use_flash_attention = False @@ -965,11 +965,7 @@ def get_attention_backend( "Available backends = {FlashAttention=%s%s, FusedAttention=%s%s," " UnfusedDotProductAttention=%s}", bool(available_backends[0]), - ( - f" ({str(flash_attention_backend)})" - if flash_attention_backend is not None - else "" - ), + (f" ({str(flash_attention_backend)})" if flash_attention_backend is not None else ""), bool(available_backends[1]), ( f" (sub-backend {int(fused_attention_backend)})" @@ -3127,7 +3123,9 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, 0) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -3242,7 +3240,9 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) if _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -3359,7 +3359,9 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -3916,7 +3918,9 @@ def forward( max_seqlen_q, 
max_seqlen_kv_, ] - if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_forward_kwargs["window_size"] = window_size_per_step[i] elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -6043,9 +6047,9 @@ def forward( fa_optional_forward_kwargs["block_table"] = ( inference_params.cache_manager.page_table[:batch_size] if inference_params.is_paged - else inference_params.cache_manager.batch_indices_post_step.unsqueeze(1)[ - :batch_size - ] + else inference_params.cache_manager.batch_indices_post_step.unsqueeze( + 1 + )[:batch_size] ) output = func( query_layer, @@ -8720,7 +8724,9 @@ def forward( elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) else: - raise ValueError(f"qkv_format={self.qkv_format} not supported for KV caching and RoPE.") + raise ValueError( + f"qkv_format={self.qkv_format} not supported for KV caching and RoPE." 
+ ) sequence_start = inference_params.get_seqlens_pre_step() # sequence_start = inference_params.seqlens[0] diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 40797631be..2942cf90ca 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -271,7 +271,9 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to(dtype=torch.int32, device="cpu") + return torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -303,8 +305,8 @@ def convert_paged_to_nonpaged(self, layer_number: int): b=batch_size, ) - new_k_cache = new_k_cache[:self.batch_size].contiguous() - new_v_cache = new_v_cache[:self.batch_size].contiguous() + new_k_cache = new_k_cache[: self.batch_size].contiguous() + new_v_cache = new_v_cache[: self.batch_size].contiguous() return new_k_cache, new_v_cache From bc4c452c64a19481bbd08826b34e8c5fadfa1890 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Tue, 4 Mar 2025 00:58:47 +0100 Subject: [PATCH 163/239] [common] Removed tensor boundary checks in MXFP8 kernels (#1519) Added constexpr checks of tensor boundaries Signed-off-by: Oleg Goncharov --- transformer_engine/common/util/cast_kernels.cuh | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index d1ede8d98d..b4b86fe708 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -261,7 +261,13 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } } in_compute[j] = elt; - if (!out_of_bounds) { + + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + 
thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } } @@ -320,7 +326,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } } in_compute[i] = elt; - if (!out_of_bounds) { + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this amax = fmaxf(amax, fabsf(elt)); } } From 90d5d45d69c94840dd8ab89761c197f277b0174b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 4 Mar 2025 10:05:29 +0530 Subject: [PATCH 164/239] Add sanity test for lightning-thunder integration (#1531) Signed-off-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_unittest/test.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index fe36b33384..17fb4d1827 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -4,8 +4,9 @@ : ${TE_PATH:=/opt/transformerengine} +: ${LIGHTNING_THUNDER_PATH:=/opt/pytorch/lightning-thunder} -pip install pytest==8.2.1 +pip install pytest==8.2.1 pytest-benchmark==5.1.0 FAIL=0 @@ -24,5 +25,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 +pytest -v -s ${LIGHTNING_THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py exit $FAIL From cbb96f2b0c1c36f55b423bc34db884bb4210d29f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Mon, 3 Mar 2025 21:16:58 -0800 Subject: [PATCH 165/239] Export only necessary symbols from libtransformer_engine.so (#1511) * Expose only 
required symbols from libtransformer_engine.so during linking for pytorch Signed-off-by: Kshitij Janardan Lakhani * Augment libtransformer_engine.version for jax compatibility Signed-off-by: Kshitij Janardan Lakhani * Augment the libtransformer_engine.version to ensure compatibility with CPP tests Remove getenv from the .version file Combine system.cpp and system.h Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Nit: Remove commented code for not including common.h Signed-off-by: Kshitij Janardan Lakhani * Replace explicit getenv instantiations with a helper template Use filesystem calls in file_exists() Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert comment to falsy instead of false Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> --------- Signed-off-by: Kshitij Janardan Lakhani Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/common/CMakeLists.txt | 1 - .../common/libtransformer_engine.version | 20 ++++- transformer_engine/common/util/system.cpp | 76 ----------------- transformer_engine/common/util/system.h | 83 +++++++++++++++++-- .../jax/csrc/extensions/attention.cpp | 14 ++-- transformer_engine/jax/csrc/extensions/misc.h | 8 ++ 6 files changed, 111 insertions(+), 91 deletions(-) delete mode 100644 transformer_engine/common/util/system.cpp diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index c77d230ce5..68231f6c04 100644 --- a/transformer_engine/common/CMakeLists.txt +++ 
b/transformer_engine/common/CMakeLists.txt @@ -80,7 +80,6 @@ list(APPEND transformer_engine_SOURCES util/cuda_driver.cpp util/cuda_runtime.cpp util/rtc.cpp - util/system.cpp swizzle/swizzle.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 0683ec01ea..546f7f3403 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -1,4 +1,20 @@ { - global: *nvte*; *transformer_engine*; + global: + extern "C++" { + nvte_*; + transformer_engine::cuda::sm_count*; + transformer_engine::cuda::sm_arch*; + transformer_engine::cuda::supports_multicast*; + transformer_engine::cuda::stream_priority_range*; + transformer_engine::cuda::current_device*; + transformer_engine::cuda_driver::get_symbol*; + transformer_engine::ubuf_built_with_mpi*; + *transformer_engine::rtc*; + transformer_engine::nvte_cudnn_handle_init*; + transformer_engine::typeToSize*; + *transformer_engine::CommOverlapBase*; + *transformer_engine::CommOverlapP2PBase*; + *transformer_engine::CommOverlapCore* + }; local: *; -}; +}; \ No newline at end of file diff --git a/transformer_engine/common/util/system.cpp b/transformer_engine/common/util/system.cpp deleted file mode 100644 index 502dced9fc..0000000000 --- a/transformer_engine/common/util/system.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "../util/system.h" - -#include -#include -#include -#include -#include -#include - -#include "../common.h" - -namespace transformer_engine { - -namespace { - -template -inline typename std::enable_if::value, T>::type getenv_helper( - const char *variable, const T &default_value) { - // Implementation for numeric types - const char *env = std::getenv(variable); - if (env == nullptr || env[0] == '\0') { - return default_value; - } - T value; - std::istringstream iss(env); - iss >> value; - NVTE_CHECK(iss, "Invalid environment variable value"); - return value; -} - -template -inline typename std::enable_if::value, T>::type getenv_helper( - const char *variable, const T &default_value) { - // Implementation for string-like types - const char *env = std::getenv(variable); - if (env == nullptr || env[0] == '\0') { - return default_value; - } else { - return env; - } -} - -} // namespace - -#define NVTE_INSTANTIATE_GETENV(T, default_value) \ - template <> \ - T getenv(const char *variable, const T &default_value_) { \ - return getenv_helper(variable, default_value_); \ - } \ - template <> \ - T getenv(const char *variable) { \ - return getenv_helper(variable, default_value); \ - } -NVTE_INSTANTIATE_GETENV(bool, false); -NVTE_INSTANTIATE_GETENV(float, 0.f); -NVTE_INSTANTIATE_GETENV(double, 0.); -NVTE_INSTANTIATE_GETENV(int8_t, 0); -NVTE_INSTANTIATE_GETENV(int16_t, 0); -NVTE_INSTANTIATE_GETENV(int32_t, 0); -NVTE_INSTANTIATE_GETENV(int64_t, 0); -NVTE_INSTANTIATE_GETENV(uint8_t, 0); -NVTE_INSTANTIATE_GETENV(uint16_t, 0); -NVTE_INSTANTIATE_GETENV(uint32_t, 0); -NVTE_INSTANTIATE_GETENV(uint64_t, 0); -NVTE_INSTANTIATE_GETENV(std::string, std::string()); -NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path()); - -bool file_exists(const std::string &path) { return static_cast(std::ifstream(path.c_str())); } - -} // namespace transformer_engine diff --git 
a/transformer_engine/common/util/system.h b/transformer_engine/common/util/system.h index 71c7ef3216..5636ab5095 100644 --- a/transformer_engine/common/util/system.h +++ b/transformer_engine/common/util/system.h @@ -7,25 +7,96 @@ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_ +#include +#include +#include +#include +#include #include +#include "logging.h" + namespace transformer_engine { +namespace detail { + +/*! \brief Template specialization to get the env var for numeric data types */ +template +inline typename std::enable_if::value, T>::type getenv_helper( + const char *variable, const T &default_value) { + // Implementation for numeric types + const char *env = std::getenv(variable); + if (env == nullptr || env[0] == '\0') { + return default_value; + } + T value; + std::istringstream iss(env); + iss >> value; + NVTE_CHECK(iss, "Invalid environment variable value"); + return value; +} + +/*! \brief Template specialization to get the env var for string-like data types */ +template +inline typename std::enable_if::value, T>::type getenv_helper( + const char *variable, const T &default_value) { + // Implementation for string-like types + const char *env = std::getenv(variable); + if (env == nullptr || env[0] == '\0') { + return default_value; + } else { + return env; + } +} + +/*! \brief Template specialization to get the default values for different +* numeric data types +*/ +template +inline T getenv_default_value() { + return 0; +} + +/*! \brief Template specialization to get the default values for bool */ +template <> +inline bool getenv_default_value() { + return false; +} + +/*! \brief Template specialization to get the default values for string */ +template <> +inline std::string getenv_default_value() { + return std::string(); +} + +/*! 
\brief Template specialization to get the default values for filesystem +* path data type */ +template <> +inline std::filesystem::path getenv_default_value() { + return std::filesystem::path(); +} + +} // namespace detail + /*! \brief Get environment variable and convert to type * * If the environment variable is unset or empty, a falsy value is * returned. - */ +*/ template -T getenv(const char *variable); +inline T getenv(const char *variable) { + return detail::getenv_helper(variable, detail::getenv_default_value()); +} /*! \brief Get environment variable and convert to type */ template -T getenv(const char *variable, const T &default_value); +inline T getenv(const char *variable, const T &default_value) { + return detail::getenv_helper(variable, default_value); +} -/*! \brief Check if a file exists and can be read */ -bool file_exists(const std::string &path); +inline bool file_exists(const std::string &path) { + return std::filesystem::exists(path) && std::filesystem::is_regular_file(path); +} } // namespace transformer_engine - #endif // TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_ diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index bc8c2c9aeb..aa5b46d054 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -540,7 +540,8 @@ static void FusedAttnBackwardImpl( auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), + stream); } nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 @@ -558,8 +559,9 @@ static void FusedAttnBackwardImpl( auto dq_tensor = TensorWrapper(dq, q_shape, dtype); auto 
dkv_tensor = TensorWrapper(dk, kv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), + stream); } nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -581,9 +583,9 @@ static void FusedAttnBackwardImpl( auto dk_tensor = TensorWrapper(dk, k_shape, dtype); auto dv_tensor = TensorWrapper(dv, v_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, transformer_engine::product(k_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 7ccfc85e8e..7cb83a0f9e 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -26,5 +26,13 @@ struct Shape { std::vector MakeShapeVector(NVTEShape shape); +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + } // namespace jax } // namespace transformer_engine From e85d1806638eb0661c2933f7f14523bdd07d36f0 Mon Sep 17 
00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 4 Mar 2025 12:43:16 -0800 Subject: [PATCH 166/239] Update list of CI users (#1535) Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index cef039f976..cd20e8b2d0 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -43,6 +43,7 @@ jobs: || github.actor == 'youngeunkwon0405' || github.actor == 'KshitijLakhani' || github.actor == 'jberchtold-nvidia' + || github.actor == 'negvet' ) steps: - name: Check if comment is issued by authorized person From f8eddcf9278416815c286aac74cbbba96bf66ad3 Mon Sep 17 00:00:00 2001 From: Nicolas Castet <26874160+nvcastet@users.noreply.github.com> Date: Tue, 4 Mar 2025 18:27:02 -0600 Subject: [PATCH 167/239] Add support for UB MNNVL (#1470) * Add support for UB MNNVL Signed-off-by: Nicolas Castet * Address review comments Signed-off-by: Nicolas Castet * Fix lint Signed-off-by: Nicolas Castet * Dlopen nvml lib since it comes with the cuda driver Signed-off-by: Nicolas Castet * Add initial copyright date Signed-off-by: Nicolas Castet --------- Signed-off-by: Nicolas Castet --- .../te_layer_with_overlap.py | 104 ++---- .../distributed/run_gemm_with_overlap.py | 19 +- .../distributed/run_layer_with_overlap.py | 24 +- transformer_engine/common/CMakeLists.txt | 1 + .../userbuffers/userbuffers-host.cpp | 314 ++++++++++-------- .../userbuffers/userbuffers.cu | 54 +-- .../userbuffers/userbuffers.h | 25 +- .../common/util/cuda_driver.cpp | 80 ----- transformer_engine/common/util/cuda_nvml.cpp | 26 ++ transformer_engine/common/util/cuda_nvml.h | 69 ++++ .../common/util/shared_lib_wrapper.h | 64 ++++ transformer_engine/pytorch/csrc/extensions.h | 3 +- .../csrc/extensions/comm_gemm_overlap.cpp | 18 +- .../pytorch/csrc/extensions/pybind.cpp | 5 +- transformer_engine/pytorch/module/base.py | 80 +---- 15 files 
changed, 438 insertions(+), 448 deletions(-) create mode 100644 transformer_engine/common/util/cuda_nvml.cpp create mode 100644 transformer_engine/common/util/cuda_nvml.h create mode 100644 transformer_engine/common/util/shared_lib_wrapper.h diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index d94c352401..e510df1761 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None): help="Disable the comm+GEMM overlap.", ) parser.add_argument( - "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + "--num-replicas", + type=int, + default=1, + help="Number of data-parallel model replicas per node.", + ) + parser.add_argument( + "--use-global-replica-count", + action="store_true", + default=False, + help="Treat '--num-replicas' as the total number of replicas.", ) parser.add_argument( "--tcp-init", @@ -173,13 +182,12 @@ def _train(opts): opts.tcp_init = True opts.bind_to_device = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: # TORCHELASTIC, SLURM, etc... 
WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + NUM_NODES = WORLD_SIZE // LOCAL_SIZE # Initialize torch.distributed global process group and get DP/TP groups @@ -214,90 +222,24 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") - # Figure out process groups for tensor- and data-parallelism (if any) - if NUM_NODES > 1: - # Create a list of world ranks on this node - hostname = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), - ) - - if ifname is not None: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - hostname = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - - hostnames = [None for _ in range(WORLD_SIZE)] - dist.all_gather_object(hostnames, hostname) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - assert len(unique_hosts) == NUM_NODES - - ranks_per_node_list = [[] for _ in range(NUM_NODES)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0 - self_node_ranks = ranks_per_node_list[self_node_idx] - - if opts.num_replicas > 1: - # Split node ranks into multiple replicas - assert len(self_node_ranks) % opts.num_replicas == 0 - tp_size = len(self_node_ranks) // 
opts.num_replicas - ranks_per_replica_list = [] - for node_ranks in ranks_per_node_list: - for i in range(opts.num_replicas): - start = i * tp_size - end = start + tp_size - ranks_per_replica_list.append(node_ranks[start:end]) - - self_replica_idx = -1 - for i, replica_ranks in enumerate(ranks_per_replica_list): - if WORLD_RANK in replica_ranks: - self_replica_idx = i - break - assert self_replica_idx >= 0 + total_replicas = ( + opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES + ) + tp_size = WORLD_SIZE // total_replicas - else: - # The entire node is the tensor-parallel group - ranks_per_replica_list = ranks_per_node_list - self_replica_idx = self_node_idx + if total_replicas > 1: + ranks_per_replica_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(total_replicas) + ] tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) dp_group, _ = dist.new_subgroups_by_enumeration( ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" ) - else: - if opts.num_replicas > 1: - # Mixed data- and tensor-parallelism on a single node - # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions - all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") - ranks_per_replica_tensor = all_ranks.reshape( - (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) - ) - tp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.tolist(), backend="nccl" - ) - dp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" - ) - else: - dp_group = None - tp_group = nccl_world + dp_group = None + tp_group = nccl_world tp_rank = dist.get_rank(tp_group) tp_size = dist.get_world_size(tp_group) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py 
index 9e11e07e11..4bbdd23fd6 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -180,15 +180,22 @@ def _main(opts): LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) opts.tcp_init = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: # TORCHELASTIC, SLURM, etc... WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node - assert LOCAL_SIZE <= torch.cuda.device_count() + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + + result = subprocess.run( + "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'", + capture_output=True, + text=True, + shell=True, + ) + + if result.stdout == "0": # Extra checks for non-MNNVL platforms + assert WORLD_SIZE == LOCAL_SIZE + assert LOCAL_SIZE <= torch.cuda.device_count() # Fix clock speed torch.cuda.set_device(LOCAL_RANK) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index d4a01386ee..39200775c9 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -7,6 +7,7 @@ import os import sys import socket +import subprocess import argparse import warnings import pprint @@ -209,14 +210,21 @@ def _train(opts): opts.tcp_init = True opts.bind_to_device = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise 
RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - assert LOCAL_SIZE == WORLD_SIZE + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + + result = subprocess.run( + "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'", + capture_output=True, + text=True, + shell=True, + ) + + if result.stdout == "0": # Extra checks for non-MNNVL platforms + assert WORLD_SIZE == LOCAL_SIZE def dist_print(msg, src=None, end="\n", debug=False, error=False): if debug and not opts.debug: @@ -227,7 +235,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist.barrier() # Set device and initialize RNG states - torch.cuda.set_device(WORLD_RANK) + torch.cuda.set_device(LOCAL_RANK) torch.manual_seed(opts.seed) torch.cuda.manual_seed(opts.seed) @@ -312,7 +320,7 @@ def run_fwd_bwd(model, x): return out torch_rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}")) + cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}")) if opts.use_cuda_graphs: test_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(test_graph): @@ -329,7 +337,7 @@ def run_fwd_bwd(model, x): names.append(test_name + ".grad") torch.set_rng_state(torch_rng_state) - torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}")) + torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}")) if opts.use_cuda_graphs: ref_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(ref_graph): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 68231f6c04..0a2abb6e4e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -78,6 +78,7 @@ list(APPEND transformer_engine_SOURCES util/cast.cu util/padding.cu util/cuda_driver.cpp + util/cuda_nvml.cpp util/cuda_runtime.cpp util/rtc.cpp swizzle/swizzle.cu diff --git 
a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index c3453aeffe..14ff853266 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -20,6 +20,7 @@ #include #include "common/util/cuda_driver.h" +#include "common/util/cuda_nvml.h" #include "common/util/cuda_runtime.h" #include "common/util/logging.h" #include "common/util/system.h" @@ -29,7 +30,6 @@ #ifdef NVTE_UB_WITH_MPI static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD; static MPI_Comm EXT_COMM_INTRA; -static MPI_Comm EXT_COMM_INTER; #define UB_MPI_CHECK(expr) \ do { \ @@ -58,11 +58,20 @@ void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } #else #define EXT_COMM_WORLD "world" #define EXT_COMM_INTRA "intra" -#define EXT_COMM_INTER "inter" #endif #define MULTICAST_GB_TOTAL 512 +#if CUDART_VERSION < 12030 +// MNNVL: FABRIC handle support lifted from CUDA 12.3 +#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) +#define CU_IPC_HANDLE_SIZE 64 +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; +#endif + int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } #define IPCCHECK(cmd) \ @@ -82,18 +91,43 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co } \ } while (0); -int pipe_rank(communicator *comm, int step) { - int mynode = comm->myrank / comm->nvsize; - int mylocal = comm->nvrank; - int numlocal = comm->nvsize; - - int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize; - int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal; - int newnode = mynode; - newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes; - int allnodes = 
comm->nranks / comm->nvsize; - newnode = (allnodes + (newnode % allnodes)) % allnodes; - return newnode * numlocal + newlocal; +bool has_mnnvl_fabric(int device_id) { +#if CUDA_VERSION < 12040 + if (getenv("NVTE_UBDEBUG")) { + printf( + "TransformerEngine does not support multi-node NVLINK " + "since it was not built with CUDA version >= 12.4.\n"); + } + return false; +#else + bool mnnvl_fabric_support = false; + CUdevice dev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id); + int fabric_handle_supported = 0; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &fabric_handle_supported, + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev); + if (fabric_handle_supported) { + NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2); + nvmlDevice_t local_device; + NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device); + nvmlGpuFabricInfoV_t fabricInfo = {}; + fabricInfo.version = nvmlGpuFabricInfo_v2; + fabricInfo.clusterUuid[0] = '\0'; + NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo); + NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown); + if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') { + mnnvl_fabric_support = true; + } + } + if (getenv("NVTE_UBDEBUG")) { + if (mnnvl_fabric_support) { + printf("MNNVL NVLINK is supported on this platform.\n"); + } else { + printf("MNNVL NVLINK is not supported on this platform.\n"); + } + } + return mnnvl_fabric_support; +#endif } int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, @@ -122,10 +156,6 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, (*comm)->use_ce = 0; (*comm)->cga_size = 2; for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; - (*comm)->head = 0; - (*comm)->tail = 0; - (*comm)->active_nreqs = 0; - for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; int device_clock = 0; // 110 sec wait time by 
default @@ -182,29 +212,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, // ar2 has step equal to ar_nvsize int allnodes = numranks / numlocal; int nodeid = myrank / numlocal; - int datanodes = allnodes / pipenodes / tensornodes; - int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes); - (*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus); - - (*comm)->comm_inter = EXT_COMM_INTER; - (*comm)->first_node = nodeid - mynode; (*comm)->num_nodes = numnodes; (*comm)->my_node = mynode; - (*comm)->num2_nodes = tensornodes; - (*comm)->my2_node = (mynode / datanodes) % tensornodes; - (*comm)->first2_node = mynode - (*comm)->my2_node * datanodes; - - (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); - (*comm)->nblocks = 8; - (*comm)->alignblock = 1024 * 512; - (*comm)->minblock = 1024 * 2 * 1024; - (*comm)->asyncblocks = 16; - #define NBUF 2 #if CUDART_VERSION >= 12010 + bool mnnvl_fabric = has_mnnvl_fabric(cur_dev); if (!transformer_engine::getenv("UB_SKIPMC") && transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { // multicast init only for TP ops (____2 operations) @@ -215,7 +230,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, CUmulticastObjectProp mcProp = {}; mcProp.numDevices = (*comm)->ar2_nvsize; mcProp.size = (*comm)->mc_maxsize; - mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + mcProp.handleTypes = + mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; NVTE_CALL_CHECK_CUDA_DRIVER( cuMulticastGetGranularity, &gran, &mcProp, @@ -223,46 +239,78 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran; mcProp.size = mc_maxsize; (*comm)->mc_maxsize = mc_maxsize; - - // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. 
- // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the - // file descriptor and prevent cuMemImportFromShareableHandle() from correctly - // interpreting the file. Instead, we use Unix domain sockets for the kernel to - // recreate the correct file descriptor on every receiving rank. - int fd; - volatile uint32_t abortFlag = 0; - IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu; - ipcSocketResult_t ret = ipcSocketSuccess; - IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); - (*comm)->_barrier((*comm)->comm_world); - - if ((*comm)->ar2_nvrank == 0) { + if ((*comm)->ar2_nvrank == 0) NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp); - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), - (uint64_t)0); - for (int p = 1; p < (*comm)->ar2_nvsize; p++) { - (*comm)->_barrier((*comm)->comm_intra); - IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + if (mnnvl_fabric) { + CUmemFabricHandle *exphndl = + reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); + CUmemFabricHandle *tmphndl = + reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); + CUmemFabricHandle *exphndls; + NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle))); + if ((*comm)->ar2_nvrank == 0) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast(tmphndl), + (*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0); + for (int grp = 0; grp < (*comm)->ar_nvsize; + grp++) { // we do N broadcasts for N TP groups in NVL domain + int root = grp * (*comm)->ar2_nvsize; + + // It just needs to be a bcast but reuse existing allgather comm + (*comm)->_allgather( + reinterpret_cast(exphndls), (*comm)->nvsize * sizeof(CUmemFabricHandle), + reinterpret_cast(tmphndl), sizeof(CUmemFabricHandle), 
(*comm)->comm_intra); + + //save data if brodcast was from rank 0 in our group + if ((*comm)->ar2_firstgpu == root) + memcpy(exphndl, exphndls + root, sizeof(CUmemFabricHandle)); } + if ((*comm)->ar2_nvrank != 0) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &(*comm)->mc_handle, + reinterpret_cast(exphndl), CU_MEM_HANDLE_TYPE_FABRIC); + free(exphndl); + free(tmphndl); + NVTE_CHECK_CUDA(cudaFreeHost(exphndls)); } else { - for (int p = 1; p < (*comm)->ar2_nvsize; p++) { - (*comm)->_barrier((*comm)->comm_intra); - if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. + // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the + // file descriptor and prevent cuMemImportFromShareableHandle() from correctly + // interpreting the file. Instead, we use Unix domain sockets for the kernel to + // recreate the correct file descriptor on every receiving rank. 
+ int fd; + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafeb000 + (*comm)->my_node + (*comm)->ar2_firstgpu; + ipcSocketResult_t ret = ipcSocketSuccess; + IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); + (*comm)->_barrier((*comm)->comm_world); + + if ((*comm)->ar2_nvrank == 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); + + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + } + } else { + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + } } - } - error: - if ((*comm)->ar2_nvrank != 0) { - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + error: + if ((*comm)->ar2_nvrank != 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + } + IPCCHECK(ipcSocketClose(&ipcSock)); + close(fd); } - IPCCHECK(ipcSocketClose(&ipcSock)); - close(fd); NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle, (CUdeviceptr)(*comm)->mydev); @@ -327,12 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, if (getenv("NVTE_UBDEBUG")) printf( - "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " - "%dx%d PIPE_ID %d/%d\n", + "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP x%d TPGROUP " + "%dx%d\n", myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, - (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, 
(*comm)->num_nodes, - (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, - pipegpus * pipenodes); + (*comm)->ar_nvrank, (*comm)->my_node, (*comm)->ar2_nvrank, (*comm)->ar_nvsize, + (*comm)->num_nodes, (*comm)->ar2_nvsize); fflush(NULL); return 0; @@ -361,43 +408,14 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks)); - // find intranode numbers and make internode communicator - char hostname[MPI_MAX_PROCESSOR_NAME]; - int namelen; - UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen)); - - char(*hostnames)[MPI_MAX_PROCESSOR_NAME] = - static_cast(malloc(numranks * MPI_MAX_PROCESSOR_NAME)); - strcpy(hostnames[myrank], hostname); // NOLINT(*) - for (int n = 0; n < numranks; n++) - UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD)); - qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp); - - int color = 0; - for (int n = 0; n < numranks; n++) { - if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++; - if (strcmp(hostname, hostnames[n]) == 0) break; - } - free(hostnames); - int mylocal, numlocal; - UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA)); + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, myrank / tensorgpus, myrank, &EXT_COMM_INTRA)); UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal)); // find internode numbers and make internode communicator NVTE_CHECK_CUDA(cudaFree(0)); - int allnodes = numranks / numlocal; - int datanodes = allnodes / pipenodes / tensornodes; - // data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1 - int datanodegroup_id = myrank / numlocal / datanodes; - // mpi communicator only needed for SHARP which is always allreduce1/data-parallel - UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + 
numlocal * datanodegroup_id, myrank, - &EXT_COMM_INTER)); - // different rails from same group are in different subcommunicators int mynode, numnodes; - UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes)); - UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode)); // finally call the abstracted constructor with MPI info return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, @@ -447,13 +465,11 @@ void destroy_communicator(communicator *comm) { if (comm->use_mc) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle); } - free(comm->fifo); delete comm; } void destroy_communicator_mpi(communicator *comm) { #ifdef NVTE_UB_WITH_MPI - MPI_Comm_free(static_cast(&(comm->comm_inter))); MPI_Comm_free(static_cast(&(comm->comm_intra))); destroy_communicator(comm); #else @@ -472,6 +488,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * #if CUDART_VERSION >= 12010 if (comm->use_mc && alloc) { + bool mnnvl_fabric = has_mnnvl_fabric(comm->mydev); int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; void **remptrs = reinterpret_cast(malloc(nranks * sizeof(void *))); @@ -481,7 +498,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = comm->mydev; prop.requestedHandleTypes = - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; // CU_MEM_HANDLE_TYPE_FABRIC; + mnnvl_fabric ? 
CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; size_t granularity = 0; NVTE_CALL_CHECK_CUDA_DRIVER( @@ -507,41 +524,58 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop, (uint64_t)0); - int *peerfd = reinterpret_cast(malloc(nranks * sizeof(int))); - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), - comm->uchandles[hndl][myrank], - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), - (uint64_t)0); - - volatile uint32_t abortFlag = 0; - IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafebeef; - ipcSocketResult_t ret = ipcSocketSuccess; - - // All-gather POSIX file descriptors across local ranks - IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); - for (int p = 1; p < nranks; p++) { - int send_to = (myrank + p) % nranks; - int recv_from = (myrank + nranks - p) % nranks; - comm->_barrier(comm->comm_intra); - IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error); - IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); - } + if (mnnvl_fabric) { + CUmemFabricHandle *exphndl; + CUmemFabricHandle myhndl; + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl, + comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0); + NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle))); + comm->_allgather(reinterpret_cast(exphndl), comm->nvsize * sizeof(CUmemFabricHandle), + reinterpret_cast(&myhndl), sizeof(CUmemFabricHandle), + comm->comm_intra); + for (int p = 0; p < nranks; p++) + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(&exphndl[p]), + CU_MEM_HANDLE_TYPE_FABRIC); + NVTE_CHECK_CUDA(cudaFreeHost(exphndl)); + } else { + int *peerfd = reinterpret_cast(malloc(nranks * 
sizeof(int))); + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), + comm->uchandles[hndl][myrank], + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); - error: - IPCCHECK(ipcSocketClose(&ipcSock)); + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafebeef + comm->my_node; + ipcSocketResult_t ret = ipcSocketSuccess; + + // All-gather POSIX file descriptors across local ranks + IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); + for (int p = 1; p < nranks; p++) { + int send_to = (myrank + p) % nranks; + int recv_from = (myrank + nranks - p) % nranks; + comm->_barrier(comm->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, + error); + IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); + } - for (int p = 0; p < nranks; p++) { - if (p != myrank) - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], - reinterpret_cast(peerfd[p]), - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - close(peerfd[p]); - } - free(peerfd); + error: + IPCCHECK(ipcSocketClose(&ipcSock)); + for (int p = 0; p < nranks; p++) { + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(peerfd[p]), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(peerfd[p]); + } + free(peerfd); + } CUdeviceptr ptr; NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks), (size_t)0, (CUdeviceptr)0, (uint64_t)0); @@ -571,13 +605,13 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * cudaMemcpy((reinterpret_cast(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice)); free(remptrs); - comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED; + comm->memflags[hndl] = 
NVTE_UB_MEM_UC_CONTIG | NVTE_UB_MEM_ALLOCATED; if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset, comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/, aligned_size, (uint64_t)0); - comm->memflags[hndl] |= UB_MEM_MC_CREATED; + comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED; comm->mc_ptr[hndl] = reinterpret_cast(comm->mc_baseptr) + comm->mc_offset; comm->mc_offset += aligned_size; } else if (!comm->myrank) { diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 58de844858..1211392e40 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1682,6 +1682,7 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) + callranks_rs_oop_stride(16) callranks_rs_oop_stride(32) } void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, const int rowelements, const int colelements, @@ -1703,7 +1704,8 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) - callranks_rs_oop_stride_atomic(8) + callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16) + callranks_rs_oop_stride_atomic(32) } template @@ -1729,6 +1731,7 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) + callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32) } template @@ -1773,7 
+1776,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) - callranks_rs_oop_stride_multiatomic(8) + callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16) + callranks_rs_oop_stride_multiatomic(32) } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, @@ -1793,17 +1797,17 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) } } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) } } } @@ -1840,17 +1844,17 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm->use_mc && 
(comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) } } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) } } } @@ -1873,17 +1877,21 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16) + callranks_rs_oopMC(32) } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) } } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16) + callranks_rs_oopMC(32) } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + callranks_rs_oop(2) 
callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) } } } @@ -1915,10 +1923,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index ee808b7f9a..84defcdb23 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -34,11 +34,7 @@ using ExtBarrierOp = std::function; #define NVTE_MAX_REQUESTS 1024 #define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_CPU 2 -#define NVTE_MAX_NVLINK 8 - -#define UB_MEM_UC_CONTIG 1 -#define UB_MEM_MC_CREATED 2 -#define UB_MEM_ALLOCATED 4 +#define NVTE_MAX_NVLINK 32 #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -124,11 +120,8 @@ struct communicator { ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup // (_splitar init used) would be equal to (nvsize,0) for regular comm_create int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step - int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size) int sm_arch; - int num_nodes, my_node, - first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes) - int num2_nodes, my2_node, 
first2_node; // with num_nodes as a stride + int num_nodes, my_node; // max value for running block counters in hostflags int basecounter[userbuffers_op_types]; // NOLINT(*) @@ -136,20 +129,11 @@ struct communicator { void *mem_mr[NVTE_MAX_REGIONS]; - ub_request *fifo; - int nblocks, alignblock, minblock, asyncblocks, active_nreqs; - ub_request active_req[userbuffers_op_types]; // NOLINT(*) - int padding[7]; - volatile int head; - int padding2[15]; - volatile int tail; - // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) ExtAllgatherOp _allgather; ExtBarrierOp _barrier; ExtComm comm_world; - ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail ExtComm comm_intra; // full intranode (all ndev GPUS) #ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; @@ -199,11 +183,6 @@ void destroy_communicator_mpi(communicator *comm); returned offset is offset of gpubuff relative to buffer registered */ -int pipe_rank(communicator *comm, - int step); // helper function to help walk across allreduce1 x allreduce2 groups - // data-parallel and tensor-parallel position within data and tensor - // groups would be preserved - int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc); /* returns handler and registers buffers. assumed to be collective i.e. you use same groups and dont mix buffers for different operations returns -1 if cant register (too many preregistered diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 8605447c61..48fb5d77d9 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -4,8 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#include - #include #include "../common.h" @@ -13,84 +11,6 @@ namespace transformer_engine { -namespace { - -/*! 
\brief Wrapper class for a shared library - * - * \todo Windows support - */ -class Library { - public: - explicit Library(const char *filename) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support - NVTE_ERROR("Shared library initialization is not supported with Windows"); -#else - handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); - NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - ~Library() { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support -#else - if (handle_ != nullptr) { - dlclose(handle_); - } -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - Library(const Library &) = delete; // move-only - - Library(Library &&other) noexcept { swap(*this, other); } - - Library &operator=(Library other) noexcept { - // Copy-and-swap idiom - swap(*this, other); - return *this; - } - - friend void swap(Library &first, Library &second) noexcept; - - void *get() noexcept { return handle_; } - - const void *get() const noexcept { return handle_; } - - /*! \brief Get pointer corresponding to symbol in shared library */ - void *get_symbol(const char *symbol) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support - NVTE_ERROR("Shared library initialization is not supported with Windows"); -#else - void *ptr = dlsym(handle_, symbol); - NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); - return ptr; -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - private: - void *handle_ = nullptr; -}; - -void swap(Library &first, Library &second) noexcept { - using std::swap; - swap(first.handle_, second.handle_); -} - -/*! 
\brief Lazily-initialized shared library for CUDA driver */ -Library &cuda_driver_lib() { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - constexpr char lib_name[] = "nvcuda.dll"; -#else - constexpr char lib_name[] = "libcuda.so.1"; -#endif - static Library lib(lib_name); - return lib; -} - -} // namespace - namespace cuda_driver { void *get_symbol(const char *symbol) { diff --git a/transformer_engine/common/util/cuda_nvml.cpp b/transformer_engine/common/util/cuda_nvml.cpp new file mode 100644 index 0000000000..0af9cd7411 --- /dev/null +++ b/transformer_engine/common/util/cuda_nvml.cpp @@ -0,0 +1,26 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "cuda_nvml.h" + +#include "shared_lib_wrapper.h" + +namespace transformer_engine { + +namespace cuda_nvml { + +/*! \brief Lazily-initialized shared library for CUDA NVML */ +Library &cuda_nvml_lib() { + constexpr char lib_name[] = "libnvidia-ml.so.1"; + static Library lib(lib_name); + return lib; +} + +void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); } + +} // namespace cuda_nvml + +} // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_nvml.h b/transformer_engine/common/util/cuda_nvml.h new file mode 100644 index 0000000000..14131a3cdd --- /dev/null +++ b/transformer_engine/common/util/cuda_nvml.h @@ -0,0 +1,69 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ + +#include + +#include + +#include "../common.h" +#include "../util/string.h" + +namespace transformer_engine { + +namespace cuda_nvml { + +/*! \brief Get pointer corresponding to symbol in CUDA NVML library */ +void *get_symbol(const char *symbol); + +/*! \brief Call function in CUDA NVML library + * + * The CUDA NVML library (libnvidia-ml.so.1 on Linux) may be different at + * compile-time and run-time. + * + * \param[in] symbol Function name + * \param[in] args Function arguments + */ +template +inline nvmlReturn_t call(const char *symbol, ArgTs... args) { + using FuncT = nvmlReturn_t(ArgTs...); + FuncT *func = reinterpret_cast(get_symbol(symbol)); + return (*func)(args...); +} + +/*! \brief Get NVML error string + * + * \param[in] rc NVML return code + */ +inline const char *get_nvml_error_string(nvmlReturn_t rc) { + using FuncT = const char *(nvmlReturn_t); + FuncT *func = reinterpret_cast(get_symbol("nvmlErrorString")); + return (*func)(rc); +} + +} // namespace cuda_nvml + +} // namespace transformer_engine + +#define NVTE_CHECK_CUDA_NVML(expr) \ + do { \ + const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \ + if (status_NVTE_CHECK_CUDA_NVML != NVML_SUCCESS) { \ + const char *desc_NVTE_CHECK_CUDA_NVML = \ + ::transformer_engine::cuda_nvml::get_nvml_error_string(status_NVTE_CHECK_CUDA_NVML); \ + NVTE_ERROR("NVML Error: ", desc_NVTE_CHECK_CUDA_NVML); \ + } \ + } while (false) + +#define VA_ARGS(...) , ##__VA_ARGS__ +#define NVTE_CALL_CHECK_CUDA_NVML(symbol, ...) 
\ + do { \ + NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ diff --git a/transformer_engine/common/util/shared_lib_wrapper.h b/transformer_engine/common/util/shared_lib_wrapper.h new file mode 100644 index 0000000000..3ccc8239b8 --- /dev/null +++ b/transformer_engine/common/util/shared_lib_wrapper.h @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ + +#include + +namespace transformer_engine { + +/*! \brief Wrapper class for a shared library + * + * \todo Windows support + */ +class Library { + public: + explicit Library(const char *filename) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); + NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + ~Library() { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support +#else + if (handle_ != nullptr) { + dlclose(handle_); + } +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + Library(const Library &) = delete; // move-only + + void *get() noexcept { return handle_; } + + const void *get() const noexcept { return handle_; } + + /*! 
\brief Get pointer corresponding to symbol in shared library */ + void *get_symbol(const char *symbol) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + void *ptr = dlsym(handle_, symbol); + NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); + return ptr; +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + private: + void *handle_ = nullptr; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e871228b80..d8fc76a2eb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -390,8 +390,7 @@ class CommOverlapHelper : torch::CustomClassHolder { CommOverlapHelper(); CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group, - std::optional inter_node_group); + std::optional intra_node_group); ~CommOverlapHelper(); diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 30126651ce..6d05869c36 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -26,8 +26,7 @@ CommOverlapHelper::CommOverlapHelper() { } // empty constructor for NVTE_UB_WITH_MPI=1 CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_domain_group, - std::optional inter_domain_group) { + std::optional intra_domain_group) { #ifndef NVTE_UB_WITH_MPI pgs.insert({"world", world_group}); myrank = pgs["world"]->getRank(); @@ -53,20 +52,9 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, mynode = 0; numnodes = 1; } else { - // Intra-node group is different than the 
world group so there must be multiple nodes - NVTE_CHECK( - inter_domain_group.has_value(), - "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", - "identical to the world_group!"); - // Get node ID and number of nodes - NVTE_CHECK( - inter_domain_group.value()->getBackendType() == backend, - "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", - "group!", pgs["world"]->getBackendName()); - pgs.insert({"inter", inter_domain_group.value()}); - mynode = pgs["inter"]->getRank(); - numnodes = pgs["inter"]->getSize(); + mynode = myrank / numlocal; + numnodes = numranks / numlocal; } } else { // Intra-node group is not set so we assume there is only 1 node diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 442837d767..0604847235 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -285,10 +285,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) - .def(py::init, - std::optional>(), + .def(py::init>(), py::call_guard(), py::arg("world_group"), - py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); + py::arg("intra_node_group") = py::none()); py::class_, transformer_engine::CommOverlapBase, transformer_engine::CommOverlapCore>(m, "CommOverlap") diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d0f9525135..84326f58ea 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -7,9 +7,6 @@ import os import pickle import warnings -import socket -import fcntl -import struct from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager @@ -177,85 +174,32 @@ def initialize_ub( 
world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - # We have single-node NVLink so we can color based on physical node hostnames. - # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and - # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on - # the chosen bootstrap backend. - mydomain = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME") - ) - if ifname is not None: - # Make sure the ifname found in the environment is a valid network interface - if ifname in [name for _, name in socket.if_nameindex()]: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - mydomain = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - finally: - s.close() - else: - ifname_warning = ( - f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - + " attempt to detect ranks on the same node by matching " - + "'socket.gethostname()', which is known to fail on virtual clusters like " - + "Kubernetes. If Userbuffers initialization fails, please set the " - + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network " - + "interface." 
- ) - warnings.warn(ifname_warning, UserWarning) - - # Allgather the domain colors across ranks and reduce to a list of unique domains - domain_per_rank_list = [None for _ in range(world_size)] - torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group) - unique_domains = [] - for domain in domain_per_rank_list: - if domain not in unique_domains: - unique_domains.append(domain) - num_domains = len(unique_domains) - + num_domains = world_size // tp_size + mydomain_idx = world_rank // tp_size if num_domains > 1: - # DP/TP model replicated on multiple NVLink domains - ranks_per_domain_list = [[] for _ in range(num_domains)] - mydomain_idx = -1 - for i, domain in enumerate(domain_per_rank_list): - domain_idx = unique_domains.index(domain) - ranks_per_domain_list[domain_idx].append(i) - if domain == mydomain: - mydomain_idx = domain_idx - assert mydomain_idx >= 0, "Internal TE error!" - - intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(num_domains) + ] + tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( ranks_per_domain_list, backend=bootstrap_backend ) - local_rank = torch.distributed.get_rank(intra_domain_group) - intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) - - inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( - [list(ranks) for ranks in zip(*ranks_per_domain_list)], - backend=bootstrap_backend, - ) - - helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group) + local_rank = torch.distributed.get_rank(tp_domain_group) + tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group) + helper = tex.CommOverlapHelper(world_group, tp_domain_group) else: # TP model on single NVLink domain, no replication, no data-parallelism mydomain_idx = 0 local_rank = world_rank - intra_domain_ranks = list(range(world_size)) + 
tp_domain_ranks = list(range(world_size)) helper = tex.CommOverlapHelper(world_group) if world_rank == 0: - print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) + print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", + f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n", end="", flush=True, ) From 45553c4aed9285392fe4f9987065a3e881ecea7a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 4 Mar 2025 18:15:57 -0800 Subject: [PATCH 168/239] minor fix for FA3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 4 ++-- transformer_engine/pytorch/attention.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index a9c4736918..2d16a1aca4 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -409,7 +409,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config = model_configs_infer[model] num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 # flash-attn v2 requires page_size >= 256 - if backend == "FlashAttention" and not _flash_attn_3_is_installed: + if backend == "FlashAttention" and _flash_attn_3_is_installed: config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 @@ -699,6 +699,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention" and not _flash_attn_3_is_installed: + if backend == "FlashAttention" and _flash_attn_3_is_installed: config.max_seqlen_q 
= config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 37160c77e1..f0b8745296 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -991,7 +991,7 @@ def get_attention_backend( use_unfused_attention = False selected_backend = "NoBackend" if use_flash_attention: - selected_backend = "FlashAttention ({str(flash_attention_backend)})" + selected_backend = f"FlashAttention ({str(flash_attention_backend)})" elif use_fused_attention: selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" elif use_unfused_attention: @@ -7858,7 +7858,7 @@ def forward( self.logger.info("Running with UnfusedDotProductAttention backend") else: use_flash_attention = _attention_backends["use_flash_attention"] - flash_attention_backend = _attention_backends["fused_attention_backend"] + flash_attention_backend = _attention_backends["flash_attention_backend"] use_fused_attention = _attention_backends["use_fused_attention"] fused_attention_backend = _attention_backends["fused_attention_backend"] use_unfused_attention = _attention_backends["use_unfused_attention"] From 547d8dd865b60c9e080e1cfdabadbc72a2a73706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Agostinho?= <159524688+sagostinho-nvidia@users.noreply.github.com> Date: Wed, 5 Mar 2025 19:06:23 +0100 Subject: [PATCH 169/239] Don't touch nor send messages to the root logger. 
(#1380) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Signed-off-by: Sérgio Agostinho --- transformer_engine/jax/__init__.py | 4 +++- transformer_engine/pytorch/__init__.py | 4 +++- transformer_engine/pytorch/attention.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 31f597c37f..1be0ceb36a 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -12,6 +12,8 @@ from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +_logger = logging.getLogger(__name__) + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -36,7 +38,7 @@ def _load_library(): if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( + _logger.info( "Could not find package %s. Install transformer-engine using 'pip" " install transformer-engine[jax]==VERSION'", module_name, diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 966115c29e..753834e057 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -19,6 +19,8 @@ from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +_logger = logging.getLogger(__name__) + @functools.lru_cache(maxsize=None) def torch_version() -> tuple[int, ...]: @@ -49,7 +51,7 @@ def _load_library(): if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( + _logger.info( "Could not find package %s. 
Install transformer-engine using 'pip" " install transformer-engine[pytorch]==VERSION'", module_name, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index cc92c1377d..ba58fe32ce 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -98,7 +98,7 @@ _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") _stream_handler = logging.StreamHandler() _stream_handler.setFormatter(_formatter) -fa_logger = logging.getLogger() +fa_logger = logging.getLogger(__name__) fa_logger.setLevel(_log_level) if not fa_logger.hasHandlers(): fa_logger.addHandler(_stream_handler) From a3e6ed80b067c8ed10adf65e849ce209cfd613cb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 5 Mar 2025 23:52:49 +0530 Subject: [PATCH 170/239] Fix installation from PyPI wheels after a source install (#1526) * Fix wheel install after src install Signed-off-by: Kirthi Shankar Sivamani * Fix JAX imports Signed-off-by: Kirthi Shankar Sivamani * switch order of dirs for finding so Signed-off-by: Kirthi Shankar Sivamani * Use existing dir src build Signed-off-by: Kirthi Shankar Sivamani * Fix lint Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- build_tools/build_ext.py | 9 +++++++-- examples/jax/encoder/common.py | 2 +- pylintrc | 2 -- tests/jax/conftest.py | 4 +++- tests/jax/test_fused_attn.py | 2 +- transformer_engine/common/__init__.py | 7 +++++++ transformer_engine/jax/__init__.py | 18 ++++++++++++++---- transformer_engine/jax/attention.py | 10 +++++----- .../jax/cpp_extensions/activation.py | 4 ++-- .../jax/cpp_extensions/attention.py | 6 +++--- .../jax/cpp_extensions/custom_call.py | 2 +- transformer_engine/jax/cpp_extensions/misc.py | 4 ++-- .../jax/cpp_extensions/normalization.py | 2 +- .../jax/cpp_extensions/quantization.py | 4 ++-- .../jax/cpp_extensions/softmax.py | 2 +- .../jax/cpp_extensions/transpose.py | 4 ++-- 
transformer_engine/jax/fp8.py | 6 +++--- transformer_engine/jax/setup.py | 2 +- transformer_engine/pytorch/__init__.py | 8 ++++++-- transformer_engine/pytorch/setup.py | 2 +- 20 files changed, 63 insertions(+), 37 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index a3243d087b..f0724f617e 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -94,7 +94,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: print(f"Time for build_ext: {total_time:.2f} seconds") -def get_build_ext(extension_cls: Type[setuptools.Extension]): +def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False): class _CMakeBuildExtension(extension_cls): """Setuptools command with support for CMake extension modules""" @@ -130,7 +130,12 @@ def run(self) -> None: self.extensions = all_extensions # Ensure that binaries are not in global package space. - target_dir = install_dir / "transformer_engine" + lib_dir = ( + "wheel_lib" + if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib + else "" + ) + target_dir = install_dir / "transformer_engine" / lib_dir target_dir.mkdir(exist_ok=True, parents=True) for ext in Path(self.build_lib).glob("*.so"): diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index 93dbd408ea..2785deac0c 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -4,7 +4,7 @@ """Shared functions for the encoder tests""" from functools import lru_cache -from transformer_engine.transformer_engine_jax import get_device_compute_capability +from transformer_engine_jax import get_device_compute_capability @lru_cache diff --git a/pylintrc b/pylintrc index 4af0c6b427..50f85fad9d 100644 --- a/pylintrc +++ b/pylintrc @@ -4,8 +4,6 @@ extension-pkg-whitelist=flash_attn_2_cuda, transformer_engine_torch, transformer_engine_jax -extension-pkg-allow-list=transformer_engine.transformer_engine_jax - 
disable=too-many-locals, too-few-public-methods, too-many-public-methods, diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index d1558710c7..663a954184 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -6,7 +6,9 @@ import jax import pytest -from transformer_engine.transformer_engine_jax import get_device_compute_capability + +import transformer_engine.jax +from transformer_engine_jax import get_device_compute_capability @pytest.fixture(autouse=True, scope="function") diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 037e364a7e..745f1cc633 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -38,7 +38,7 @@ ReorderStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper -from transformer_engine.transformer_engine_jax import ( +from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, ) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index efcd4dc0b0..a8c845efd8 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -83,6 +83,13 @@ def _load_library(): """Load shared library with Transformer Engine C extensions""" so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}" + if not so_path.exists(): + so_path = ( + get_te_path() + / "transformer_engine" + / "wheel_lib" + / f"libtransformer_engine.{_get_sys_extension()}" + ) if not so_path.exists(): so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}" assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}" diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 1be0ceb36a..4e38438a97 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -5,7 +5,10 @@ # pylint: disable=wrong-import-position,wrong-import-order +import sys import logging +import 
importlib +import importlib.util import ctypes from importlib.metadata import version @@ -49,13 +52,20 @@ def _load_library(): so_dir = get_te_path() / "transformer_engine" so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: - so_dir = get_te_path() - so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + try: + so_dir = get_te_path() / "transformer_engine" / "wheel_lib" + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + except StopIteration: + so_dir = get_te_path() + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) - return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + spec = importlib.util.spec_from_file_location(module_name, so_path) + solib = importlib.util.module_from_spec(spec) + sys.modules[module_name] = solib + spec.loader.exec_module(solib) -_TE_JAX_LIB_CTYPES = _load_library() +_load_library() from . import flax from .fp8 import fp8_autocast, update_collections, get_delayed_scaling from .fp8 import NVTE_FP8_COLLECTION_NAME diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 9b93faeb55..06629291da 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -13,11 +13,11 @@ import jax.numpy as jnp from flax.linen import make_attention_mask -from transformer_engine.transformer_engine_jax import NVTE_Bias_Type -from transformer_engine.transformer_engine_jax import NVTE_Mask_Type -from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout -from transformer_engine.transformer_engine_jax import NVTE_QKV_Format -from transformer_engine.transformer_engine_jax import nvte_get_qkv_format +from transformer_engine_jax import NVTE_Bias_Type +from transformer_engine_jax import NVTE_Mask_Type +from transformer_engine_jax import NVTE_QKV_Layout +from transformer_engine_jax import NVTE_QKV_Format +from transformer_engine_jax import nvte_get_qkv_format from . 
import cpp_extensions as tex diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 076ec98aba..704740c56d 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -13,8 +13,8 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import NVTE_Activation_Type +import transformer_engine_jax +from transformer_engine_jax import NVTE_Activation_Type from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 409f08c7db..103f97827f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -17,6 +17,9 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi + +import transformer_engine_jax +from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, @@ -26,9 +29,6 @@ SequenceDescriptor, ) -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend - from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper from .misc import ( diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 6f6c9962cf..422d81b267 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -7,7 +7,7 @@ import jax from jax.interpreters import mlir -from transformer_engine import transformer_engine_jax +import transformer_engine_jax from .misc 
import is_ffi_enabled diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 3ec6502152..4f65a2c3c7 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -15,8 +15,8 @@ from jax import dtypes from jax.interpreters.mlir import dtype_to_ir_type -from transformer_engine.transformer_engine_jax import DType as TEDType -from transformer_engine import transformer_engine_jax +from transformer_engine_jax import DType as TEDType +import transformer_engine_jax from ..sharding import get_padded_spec as te_get_padded_spec diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 1107dd3a0f..50248649ba 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -15,7 +15,7 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi -from transformer_engine import transformer_engine_jax +import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 2f29a64f18..f3ecf5e230 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -11,8 +11,8 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import DType as TEDType +import transformer_engine_jax +from transformer_engine_jax import DType as TEDType from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper diff --git a/transformer_engine/jax/cpp_extensions/softmax.py 
b/transformer_engine/jax/cpp_extensions/softmax.py index dba1f504da..42c6919d92 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -14,7 +14,7 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi -from transformer_engine import transformer_engine_jax +import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index bb9b104e7e..8353414235 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -13,8 +13,8 @@ from jax.sharding import PartitionSpec, NamedSharding from jax import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import DType as TEDType +import transformer_engine_jax +from transformer_engine_jax import DType as TEDType from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index f2dbd3b131..04ac6dd57d 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -14,9 +14,9 @@ from flax.core.frozen_dict import FrozenDict from flax.linen import fp8_ops -from transformer_engine.transformer_engine_jax import DType -from transformer_engine.transformer_engine_jax import get_cublasLt_version -from transformer_engine.transformer_engine_jax import ( +from transformer_engine_jax import DType +from transformer_engine_jax import get_cublasLt_version +from transformer_engine_jax import ( get_cuda_version, get_device_compute_capability, ) diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 0f69939f36..4f5cc4df20 100644 --- a/transformer_engine/jax/setup.py +++ 
b/transformer_engine/jax/setup.py @@ -37,7 +37,7 @@ from pybind11.setup_helpers import build_ext as BuildExtension os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) +CMakeBuildExtension = get_build_ext(BuildExtension, True) if __name__ == "__main__": diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 753834e057..888836ec7f 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -62,8 +62,12 @@ def _load_library(): so_dir = get_te_path() / "transformer_engine" so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: - so_dir = get_te_path() - so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + try: + so_dir = get_te_path() / "transformer_engine" / "wheel_lib" + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + except StopIteration: + so_dir = get_te_path() + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) spec = importlib.util.spec_from_file_location(module_name, so_path) solib = importlib.util.module_from_spec(spec) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 20503fea2f..4499c28826 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -35,7 +35,7 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) +CMakeBuildExtension = get_build_ext(BuildExtension, True) if __name__ == "__main__": From 6ff7b7042d5351b8212926d81c07d3fc6a8fe576 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 5 Mar 2025 10:43:26 -0800 Subject: [PATCH 171/239] [PyTorch] Move Lightning-Thunder integration test to L1 (#1536) Move Lightning-Thunder integration test to L1 Signed-off-by: Tim Moon --- qa/L0_pytorch_unittest/test.sh | 5 ++--- qa/L1_pytorch_thunder_integration/test.sh | 19 +++++++++++++++++++ 2 files changed, 21 
insertions(+), 3 deletions(-) create mode 100644 qa/L1_pytorch_thunder_integration/test.sh diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17fb4d1827..56d668bd12 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -2,11 +2,11 @@ # # See LICENSE for license information. +set -x : ${TE_PATH:=/opt/transformerengine} -: ${LIGHTNING_THUNDER_PATH:=/opt/pytorch/lightning-thunder} -pip install pytest==8.2.1 pytest-benchmark==5.1.0 +pip install pytest==8.2.1 FAIL=0 @@ -25,6 +25,5 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 -pytest -v -s ${LIGHTNING_THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py exit $FAIL diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh new file mode 100644 index 0000000000..1737ca9ba1 --- /dev/null +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -x + +: ${THUNDER_PATH:=/opt/pytorch/lightning-thunder} + +pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 +python3 -m pytest -v -s ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py + +# Check return code +# Note: Return code 5 is fine. Lightning tests are skipped on systems +# without FP8 support and Pytest returns 5 if no tests are run. +RC=$? 
+if [ ${RC} -eq 5 ]; then + RC=0 +fi +exit ${RC} From bd278fffa133a135cd0879e179fb48ba65184bf5 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 6 Mar 2025 08:57:12 -0800 Subject: [PATCH 172/239] [PyTorch] Enable MXFP8 LayerNorm and RMSNorm (#1487) * Enable MXFP8 LayerNorm and RMSNorm Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix compilation Signed-off-by: Kirthi Shankar Sivamani * Fix envvar Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Tim Moon Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/csrc/extensions/normalization.cpp | 162 ++++++++++-------- 1 file changed, 88 insertions(+), 74 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 66ad03381c..bb011faf98 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include "common/util/system.h" #include "extensions.h" namespace transformer_engine::pytorch { @@ -70,80 +71,85 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, } std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, - float eps, py::object ln_out, py::handle quantizer, + float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { using namespace transformer_engine::pytorch; using namespace transformer_engine; + // Input and param tensors auto none = py::none(); - const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); - - TensorWrapper bias_tensor; - MaybeTensor bias_grad = std::nullopt; + const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); + TensorWrapper bias_cu; if (bias.has_value()) { - bias_tensor = makeTransformerEngineTensor(*bias); + bias_cu = makeTransformerEngineTensor(*bias); } // Tensor dimensions - size_t N = static_cast(input_tensor.size(0)); - size_t H = static_cast(input_tensor.size(1)); - std::vector size = {N, H}; + const size_t N = static_cast(input_cu.size(0)); + const size_t H = static_cast(input_cu.size(1)); + const std::vector size = {N, H}; - // Construct Transformer Engine tensors + // Tensors to save for backward pass at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + TensorWrapper mu_cu = makeTransformerEngineTensor(mu); + TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); - TensorWrapper ln_out_tensor; + // Output tensor std::unique_ptr my_quantizer = convert_quantizer(quantizer); - py::object ln_output; + TensorWrapper out_cu; + if (out.is_none()) { + 
std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + } else { + out_cu = makeTransformerEngineTensor(out, quantizer); + } + // Determine whether to avoid fused kernel + bool force_unfused_kernel = false; if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - // Use high precision output from normalization - NoneQuantizer q{none}; - std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype); - } else { - if (ln_out.is_none()) { - std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); - } else { - ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + // TE only supports MXFP8 norm with cuDNN backend + force_unfused_kernel = true; + } else if (N % 128 != 0 || H % 128 != 0) { + // cuDNN norm requires full tile for MXFP8 + force_unfused_kernel = true; } } - TensorWrapper mu_cu = makeTransformerEngineTensor(mu); - TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + TensorWrapper unquantized_out_cu; + if (force_unfused_kernel) { + NoneQuantizer q{none}; + py::object unquantized_out; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } + TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; - // Query workspace sizes + // Query workspace size transformer_engine::TensorWrapper workspace; - nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, - ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - // Allocate workspaces + // Allocate workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, - ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - TensorWrapper cast_out_tensor; - if (ln_out.is_none()) { - std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); - } else { - cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); - } - - nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + // Quantize output if using unfused kernel + if (force_unfused_kernel) { + nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); } - return {ln_out, py::cast(mu), py::cast(rsigma)}; + return {out, py::cast(mu), py::cast(rsigma)}; } std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -187,69 
+193,77 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, } std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, - py::object ln_out, py::handle quantizer, - transformer_engine::DType otype, const int sm_margin, + py::object out, py::handle quantizer, + transformer_engine::DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { using namespace transformer_engine::pytorch; using namespace transformer_engine; + // Input and param tensors auto none = py::none(); - const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); + const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); // Tensor dimensions - size_t N = static_cast(input_tensor.shape().data[0]); - size_t H = static_cast(input_tensor.shape().data[1]); + const size_t N = static_cast(input_cu.shape().data[0]); + const size_t H = static_cast(input_cu.shape().data[1]); + const std::vector size = {N, H}; - // Construct Transformer Engine tensors + // Tensors to save for backward pass auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - std::vector size = {N, H}; - TensorWrapper ln_out_tensor; + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + + // Output tensor std::unique_ptr my_quantizer = convert_quantizer(quantizer); - py::object ln_output; + TensorWrapper out_cu; + if (out.is_none()) { + std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + } else { + out_cu = makeTransformerEngineTensor(out, quantizer); + } + // Determine whether to avoid fused kernel + bool force_unfused_kernel = false; if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - // Use high precision output from normalization - NoneQuantizer q{none}; - std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, otype); - } else { - if 
(ln_out.is_none()) { - std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); - } else { - ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + // TE only supports MXFP8 norm with cuDNN backend + force_unfused_kernel = true; + } else if (N % 128 != 0 || H % 128 != 0) { + // cuDNN norm requires full tile for MXFP8 + force_unfused_kernel = true; } } - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + TensorWrapper unquantized_out_cu; + if (force_unfused_kernel) { + NoneQuantizer q{none}; + py::object unquantized_out; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } + TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; - // Query workspace sizes + // Query workspace size transformer_engine::TensorWrapper workspace; - nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), - rsigma_cu.data(), workspace.data(), + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - // Allocate workspaces + // Allocate workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), - rsigma_cu.data(), workspace.data(), + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - TensorWrapper cast_out_tensor; - 
if (ln_out.is_none()) { - std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); - } else { - cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); - } - - nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + // Quantize output if using unfused kernel + if (force_unfused_kernel) { + nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); } - return {ln_out, py::none(), py::cast(rsigma)}; + return {out, py::none(), py::cast(rsigma)}; } From 74983b3689f9469f7b2474d61fb0216afb3acb5b Mon Sep 17 00:00:00 2001 From: Nicolas Castet <26874160+nvcastet@users.noreply.github.com> Date: Thu, 6 Mar 2025 12:57:29 -0600 Subject: [PATCH 173/239] Fix UB with MPI init (#1538) Signed-off-by: Nicolas Castet --- .../comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 14ff853266..e52cdd8a1f 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -280,7 +280,7 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int fd; volatile uint32_t abortFlag = 0; IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafeb000 + (*comm)->my_node + (*comm)->ar2_firstgpu; + uint64_t opId = 0xdeadcafe0000 + (*comm)->my_node; ipcSocketResult_t ret = ipcSocketSuccess; IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); (*comm)->_barrier((*comm)->comm_world); @@ -416,6 +416,8 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe // find internode numbers and make internode communicator NVTE_CHECK_CUDA(cudaFree(0)); int mynode, numnodes; + mynode = myrank / 
numlocal; + numnodes = numranks / numlocal; // finally call the abstracted constructor with MPI info return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, @@ -549,7 +551,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * volatile uint32_t abortFlag = 0; IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafebeef + comm->my_node; + uint64_t opId = 0xdeadcafe0000 + comm->my_node; ipcSocketResult_t ret = ipcSocketSuccess; // All-gather POSIX file descriptors across local ranks From e1c4f51ed24b8eb18066fac0d3236786ae90860b Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 6 Mar 2025 11:18:12 -0800 Subject: [PATCH 174/239] make sure dout is contiguous (#1539) Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ba58fe32ce..0638c040b1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2858,6 +2858,7 @@ def backward(ctx, dout): # [b, np, sq] -> [b, np, sq, 1] or # [t, np] -> [t, np, 1] softmax_lse.unsqueeze_(-1) + dout = dout.contiguous() dq = None dout_dtype = dout.dtype From 971001394167f17e2a9c669983868cd1eb0ae7fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 6 Mar 2025 20:28:27 +0100 Subject: [PATCH 175/239] [PyTorch] Fix issue when last input in GroupedLinear is empty. 
* fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * test Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more sensitive tests Signed-off-by: Pawel Gadzinski * typo fix and skip test on blackwell fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_sanity.py | 50 +++++++++++++++++++ .../pytorch/csrc/extensions/gemm.cpp | 11 +++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index d3bf34943d..1e6250f26f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -25,6 +25,7 @@ from transformer_engine.pytorch import ( LayerNormLinear, Linear, + GroupedLinear, LayerNormMLP, TransformerLayer, RMSNorm, @@ -532,6 +533,55 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ assert out.shape == (num_tokens, ffn_hidden_size) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes_with_zero) +@pytest.mark.parametrize("model", ["small", "weird"]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("use_bias", all_boolean) +@pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) +@pytest.mark.parametrize("num_gemms", [4]) +def test_sanity_grouped_linear( + dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split +): + config = model_configs[model] + ffn_hidden_size = 4 * config.hidden_size + # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. 
+ bs = bs * 16 + num_tokens = bs * config.seq_len * (num_gemms - 1) + + if fp8_recipe is not None: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8(): + pytest.skip("Grouped linear does not support MXFP8") + if not config.is_fp8_supported(): + pytest.skip("Model config does not support FP8") + + use_fp8 = fp8_recipe is not None + with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): + te_grouped_linear = GroupedLinear( + num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + ).cuda() + + inp_hidden_states = torch.randn( + num_tokens, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + m_splits = [bs * config.seq_len] * num_gemms + if empty_split == "first": + m_splits[0] = 0 + elif empty_split == "last": + m_splits[-1] = 0 + elif empty_split == "middle": + m_splits[num_gemms // 2] = 0 + + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + out = te_grouped_linear(inp_hidden_states, m_splits) + loss = out.sum() + loss.backward() + assert out.shape == (num_tokens, ffn_hidden_size) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small", "weird"]) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 54bd52f136..53fed04735 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -329,9 +329,13 @@ std::optional> te_general_grouped_gemm( at::Tensor out_tensor; auto size_t_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); + bool D_numel_is_zero = false; std::vector D_shape; for (size_t t : size_t_shape) { D_shape.push_back(t); + if (t == 0) { + D_numel_is_zero = true; + } } auto dtype = GetATenDType(D_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); @@ -339,7 +343,12 @@ 
std::optional> te_general_grouped_gemm( if (output_data_ptr == nullptr) { out_tensor = at::empty(D_shape, opts); } else { - out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + // We need to check !D_numel_is_zero because if the final input portion has zero elements, + // output_data_ptr would point beyond the allocated memory of D. This would cause + // at::from_blob to fail as it would reference memory not allocated by CUDA. + if (!D_numel_is_zero) { + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + } } char* char_ptr = reinterpret_cast(output_data_ptr); char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size(); From 3db46d6c27b13efa1aa849abff5b19e92babff4d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 11:38:35 -0800 Subject: [PATCH 176/239] more minor fixes for FA3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 2d16a1aca4..e91286302c 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -409,7 +409,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config = model_configs_infer[model] num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 # flash-attn v2 requires page_size >= 256 - if backend == "FlashAttention" and _flash_attn_3_is_installed: + if backend == "FlashAttention" and not _flash_attn_3_is_installed: config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 @@ -490,9 +490,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference") # 
generate data for all requests - assert ( - config.max_seqlen_q == config.max_seqlen_kv - ), "This test only simulates max_seqlen_q = max_seqlen_kv." full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs") # generate reference results @@ -699,6 +696,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention" and _flash_attn_3_is_installed: + if backend == "FlashAttention" and not _flash_attn_3_is_installed: config.max_seqlen_q = config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv From 831866a4ff4fc4d9052b9829b81362ceac6e9b3e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 11:46:17 -0800 Subject: [PATCH 177/239] test page_size=1 for FA3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index e91286302c..55523276a2 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -420,7 +420,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda page_size = None total_num_pages = None if is_paged: - page_size = 256 if backend == "FlashAttention" and not _flash_attn_3_is_installed else 16 + page_size = 256 if backend == "FlashAttention" and not _flash_attn_3_is_installed else 1 config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) else: From 63241ad281647e2391f4e1553d002decb88d4b20 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 12:45:42 -0800 Subject: [PATCH 178/239] fix t3hd/th3d 
strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/utils.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index a079f5a8fc..516f7b84c5 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -459,7 +459,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl( case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - offsets_v[tid] = offsets_v[cu_seqlens_id]; + offsets_v[tid] = offsets_k[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_H2D: From 13bd745bff68e26a096a5e6b998446d78a08f779 Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:56:46 -0800 Subject: [PATCH 179/239] Remove cudaStreamSync. call from transformer_engine.cpp (#1518) * Remove cudaStreamSync. 
call Signed-off-by: Vasudevan Rengasamy * Use cudaMemsetAsync instead of cudaMemcpyAsync Signed-off-by: Vasudevan Rengasamy * Update transformer_engine/common/transformer_engine.cpp Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/common/transformer_engine.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index faf6ec990d..54d5b0b5bf 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -407,8 +407,6 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { } // Set amax to 0 if allocated if (t.amax.dptr != nullptr) { - float zero = 0.0f; - cudaMemcpyAsync(t.amax.dptr, &zero, sizeof(float), cudaMemcpyHostToDevice, stream); + cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); } - cudaStreamSynchronize(stream); } From de06a34cebde85736ddb5cf9056c0db4c3db3465 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Thu, 6 Mar 2025 14:58:10 -0800 Subject: [PATCH 180/239] Add NVTX ranges to FP8 amax AR and grad output preprocessing (#1530) Add NVTX ranges Signed-off-by: Jaemin Choi Co-authored-by: Jaemin Choi Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++++ transformer_engine/pytorch/module/linear.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 007821038f..2608fedeb1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py 
@@ -522,6 +522,7 @@ def backward( if ctx.grad_output_quantizer is not None: ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( grad_output, grad_bias, @@ -531,6 +532,7 @@ def backward( ctx.parallel_mode == "row", ctx.grad_output_quantizer, ) + nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") # Prepare GEMM input # Note: Perform tensor-parallel communication if needed @@ -747,7 +749,9 @@ def backward( wgrad = None if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers # if ctx.fp8 and not isinstance(weight, QuantizedTensor): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 83dc652c62..f07cfb487b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -427,6 +427,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: Cast to expected dtype and perform tensor-parallel communication if ctx.grad_output_quantizer is not None: ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( grad_output, grad_bias, @@ -436,6 +437,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.parallel_mode == "row", ctx.grad_output_quantizer, ) + nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") # Prepare input tensor # Note: Perform tensor-parallel communication if needed @@ -623,7 +625,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], wgrad = None if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + 
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers if ctx.fp8 and not isinstance(weight, QuantizedTensor): From a37058a314c0cbf1bd7d34051854dad167d524a7 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 22:08:35 -0800 Subject: [PATCH 181/239] fix ckpt recompute and fa3 k_scale Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 50fd3368cf..d3baa528c7 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6107,9 +6107,9 @@ def convert_to_torch_float8(tensor, dtype): fa_3_optional_forward_kwargs["q_descale"] = ( query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_q) ) - fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze( - 0 - ).repeat(batch_size, num_heads_k) + fa_3_optional_forward_kwargs["k_descale"] = ( + key_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) + ) fa_3_optional_forward_kwargs["v_descale"] = ( value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) ) @@ -7869,7 +7869,6 @@ def forward( raise ValueError("No dot product attention support for the provided inputs!") # run attention - output = None if use_flash_attention: if core_attention_bias_type == "alibi": alibi_slopes, _ = get_alibi( @@ -7878,7 +7877,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, ) - output = self.flash_attention( + return self.flash_attention( query_layer, key_layer, value_layer, @@ -7917,8 +7916,9 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", 
"padding_causal"], ) + #checkpoint_core_attention=False if checkpoint_core_attention: - output = self._checkpointed_attention_forward( + return self._checkpointed_attention_forward( self.fused_attention, query_layer, key_layer, @@ -7943,9 +7943,10 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, inference_params=inference_params, ) - output = self.fused_attention( + return self.fused_attention( query_layer, key_layer, value_layer, @@ -7983,7 +7984,7 @@ def forward( if use_unfused_attention: if checkpoint_core_attention: - output = self._checkpointed_attention_forward( + return self._checkpointed_attention_forward( self.unfused_attention, query_layer, key_layer, @@ -7999,7 +8000,7 @@ def forward( alibi_slopes=alibi_slopes, inference_params=inference_params, ) - output = self.unfused_attention( + return self.unfused_attention( query_layer, key_layer, value_layer, From 09c2f394a3cd1978e364dfbcdd1a0ab19c96c4ef Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 22:35:44 -0800 Subject: [PATCH 182/239] raise dynamo recompile limit for test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e56bb8868c..ebf41cbee5 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -60,6 +60,7 @@ _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() +torch._dynamo.config.recompile_limit = 16 class ModelConfig: def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): From 5e2f2a95a1a50c431f7645a35c330ebfdf3e5996 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 22:40:07 -0800 Subject: [PATCH 183/239] remove 
thunder test from L0 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 4ea8da8faf..51b2020736 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -26,6 +26,5 @@ pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || FAIL=1 -pytest -v -s ${LIGHTNING_THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py exit $FAIL From d4d82dddba27ea7b629dad3883eb41de08340a55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Mar 2025 06:41:35 +0000 Subject: [PATCH 184/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_numerics.py | 1 + transformer_engine/pytorch/attention.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index ebf41cbee5..bde4f7f5f7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -62,6 +62,7 @@ torch._dynamo.config.recompile_limit = 16 + class ModelConfig: def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): self.hidden_size = hidden_size diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d3baa528c7..def6cf531c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6107,9 +6107,9 @@ def 
convert_to_torch_float8(tensor, dtype): fa_3_optional_forward_kwargs["q_descale"] = ( query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_q) ) - fa_3_optional_forward_kwargs["k_descale"] = ( - key_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) - ) + fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze( + 0 + ).repeat(batch_size, num_heads_k) fa_3_optional_forward_kwargs["v_descale"] = ( value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) ) @@ -7916,7 +7916,7 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) - #checkpoint_core_attention=False + # checkpoint_core_attention=False if checkpoint_core_attention: return self._checkpointed_attention_forward( self.fused_attention, From 2d9a882574cc077a9fbf68f4a62602fe3c8405b1 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 23:11:02 -0800 Subject: [PATCH 185/239] fix FA selection logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index def6cf531c..1140ea7e32 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -446,10 +446,10 @@ def get_attention_backend( flash_attention_backend = None use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if (use_flash_attention_2 and _flash_attn_is_installed) or ( - use_flash_attention_3 and _flash_attn_is_installed - ): - logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_flash_attention_2 and _flash_attn_is_installed: + logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0") + if not 
use_flash_attention_3 and _flash_attn_3_is_installed: + logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0") if not use_fused_attention: logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") if not use_unfused_attention: @@ -558,7 +558,7 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: if (use_flash_attention_2 and _flash_attn_is_installed) or ( - use_flash_attention_3 and _flash_attn_is_installed + use_flash_attention_3 and _flash_attn_3_is_installed ): logger.debug("Disabling FlashAttention as it does not support MLA.") use_flash_attention = False @@ -600,7 +600,7 @@ def get_attention_backend( use_unfused_attention = False if pad_between_seqs: if (use_flash_attention_2 and _flash_attn_is_installed) or ( - use_flash_attention_3 and _flash_attn_is_installed + use_flash_attention_3 and _flash_attn_3_is_installed ): logger.debug( "Disabling FlashAttention for qkv_format = thd when there is " @@ -636,28 +636,32 @@ def get_attention_backend( logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" ) + use_flash_attention = False if "bottom_right" in attn_mask_type: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal_bottom_right masking" ) + use_flash_attention = False elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal masking for cross-attention" ) + use_flash_attention = False elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: logger.debug( "Disabling FlashAttention as it does not support context parallelism with bias" " type of %s", core_attention_bias_type, ) + use_flash_attention = False elif qkv_format == "thd" and core_attention_bias_type != "no_bias": logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " attention bias for THD format" ) - 
use_flash_attention = False + use_flash_attention = False if context_parallel and use_fused_attention: if "bottom_right" in attn_mask_type: @@ -711,7 +715,7 @@ def get_attention_backend( # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": if (use_flash_attention_2 and _flash_attn_is_installed) or ( - use_flash_attention_3 and _flash_attn_is_installed + use_flash_attention_3 and _flash_attn_3_is_installed ): logger.debug("Disabling FlashAttention for arbitrary mask") use_flash_attention = False From bb5613c1ceff985a74715ee8f2ca29c2db9ab47c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 6 Mar 2025 23:21:42 -0800 Subject: [PATCH 186/239] fix FA3 q_descale shape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 1140ea7e32..c6296c2eaf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6109,7 +6109,7 @@ def convert_to_torch_float8(tensor, dtype): num_heads_q = query_layer.shape[-2] num_heads_k = key_layer.shape[-2] fa_3_optional_forward_kwargs["q_descale"] = ( - query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_q) + query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) ) fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze( 0 From 48b8eea67a63b3e1377f65ab897a8f52ee2c2947 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sat, 8 Mar 2025 01:02:20 +0530 Subject: [PATCH 187/239] [PyTorch] Don't set FP8 data to `None` when saving base tensors (#1548) Don't set data to null Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/tensor/_internal/float8_tensor_base.py | 2 -- .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 2 -- 2 files changed, 4 deletions(-) diff --git 
a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 8ae45c9375..b0b6f98e6c 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -105,8 +105,6 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor """ tensors = [self._data, self._transpose] - self._data = None - self._transpose = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index ea7fc3cf2f..bd581feab1 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -100,8 +100,6 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB """ tensors = [self._rowwise_data, self._columnwise_data] - self._rowwise_data = None - self._columnwise_data = None return tensors, self def restore_from_saved( From 44c8fd0f3ac4131bcc6a03f47f3d1acf7858b527 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sat, 8 Mar 2025 01:11:04 +0530 Subject: [PATCH 188/239] Add user to TE CI (#1547) Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index cd20e8b2d0..681b662036 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -43,6 +43,7 @@ jobs: || github.actor == 'youngeunkwon0405' || github.actor == 'KshitijLakhani' || github.actor == 'jberchtold-nvidia' + || github.actor == 'sanandaraj5597' || github.actor == 'negvet' ) steps: From 2ad5da952e42c6fe7bd09bee8810f7f6c195cbd8 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 7 Mar 2025 
11:43:50 -0800 Subject: [PATCH 189/239] [PyTorch] Fix incorrect docstrings in tensor saving functions (#1549) Fix incorrect docstrings in tensor saving functions Signed-off-by: Tim Moon --- .../pytorch/tensor/_internal/float8_tensor_base.py | 7 +------ .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 7 +------ transformer_engine/pytorch/tensor/float8_tensor.py | 7 +------ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 7 +------ 4 files changed, 4 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index b0b6f98e6c..bb01b1ee8b 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -98,12 +98,7 @@ def get_metadata(self) -> Dict[str, Any]: } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: - """Prepare the tensor base for saving for backward - - After calling this, the tensor instance does not hold any - data. - - """ + """Prepare the tensor base for saving for backward""" tensors = [self._data, self._transpose] return tensors, self diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index bd581feab1..b818638d02 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -93,12 +93,7 @@ def get_metadata(self) -> Dict[str, Any]: } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: - """Prepare the tensor base for saving for backward - - After calling this, the tensor instance does not hold any - data. 
- - """ + """Prepare the tensor base for saving for backward""" tensors = [self._rowwise_data, self._columnwise_data] return tensors, self diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 333b8d1733..5944039cf0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -349,12 +349,7 @@ def clear(self): self._transpose_invalid = True def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: - """Prepare the tensor base for saving for backward - - After calling this, the tensor instance does not hold any - data. - - """ + """Prepare the tensor base for saving for backward""" return [self], None @classmethod diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 940f2ae46f..db369de803 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -286,12 +286,7 @@ def clear(self): self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: - """Prepare the tensor base for saving for backward - - After calling this, the tensor instance does not hold any - data. 
- - """ + """Prepare the tensor base for saving for backward""" return [self], None @classmethod From 2a95efd39128955081c60b67d49351d89f003324 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:39:18 -0800 Subject: [PATCH 190/239] CP implementation refinement for BSHD/SBHD format (#1523) * fix recompilation of out and lse correction in p2p+bshd/sbhd Signed-off-by: Xiaowei Ren * fix recompilation of get_seq_chunk_ids_for_reordering Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix recomplilation of reorder_seq_chunks_for_a2a Signed-off-by: Xiaowei Ren * recover a change Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change to softmax_lse correction Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cache cu_seqlens for BSHD/SBHD format Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * do not need to allocate out buffer for BSHD/SBHD Signed-off-by: Xiaowei Ren * code refactoring Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix Signed-off-by: Xiaowei Ren * refactor init out correction Signed-off-by: Xiaowei Ren * fix a docstring Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * code refactoring Signed-off-by: Xiaowei Ren * fix init out correct dtype Signed-off-by: Xiaowei Ren * add pad_between_seqs to DPA API Signed-off-by: Xiaowei Ren * add pad_between_seqs to the API of MHA and transformer layer Signed-off-by: Xiaowei Ren * add pad_between_seqs to the API of MHA and transformer layer Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 344 ++++++++++++++-------- transformer_engine/pytorch/transformer.py | 5 + 2 files changed, 228 insertions(+), 121 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0638c040b1..537b43496f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1589,24 +1589,52 @@ def flash_attn_p2p_communicate( return send_recv_reqs +@jit_fuser +def flash_attn_fwd_out_correction_init( + out_init_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_init_step: torch.Tensor, + seq_dim: int, +): + """Merge partial outputs of the first step in Attention with context parallelism""" + softmax_lse_corrected_exp = torch.exp(softmax_lse_init_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_init_step * softmax_lse_corrected_exp + return out_corrected.to(out_init_step.dtype) + + @jit_fuser def flash_attn_fwd_out_correction( out: torch.Tensor, out_per_step: torch.Tensor, softmax_lse: torch.Tensor, softmax_lse_per_step: torch.Tensor, - movedim_src: int, - movedim_dst: int, + seq_dim: int, ): """Merge partial outputs of each step in Attention with context parallelism""" - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim( - movedim_src, movedim_dst - ) + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) out_corrected = out_per_step * softmax_lse_corrected_exp out.add_(out_corrected) +@jit_fuser +def flash_attn_fwd_second_half_out_correction( + out: torch.Tensor, + out_per_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + seq_dim: int, +): + """Merge second half of partial outputs of each step in Attention with context 
parallelism""" + out_ = out.select(seq_dim, 1) + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :] + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse_).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out_.add_(out_corrected) + + @jit_fuser def flash_attn_fwd_softmax_lse_correction( softmax_lse: torch.Tensor, @@ -1619,6 +1647,19 @@ def flash_attn_fwd_softmax_lse_correction( softmax_lse.copy_(new_scale) +@jit_fuser +def flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): + """Merge second half of softmax stats of each step in Attention with context parallelism""" + softmax_lse_ = softmax_lse[..., 1, :] + max_scale = torch.max(softmax_lse_, softmax_lse_per_step) + min_scale = torch.min(softmax_lse_, softmax_lse_per_step) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) + softmax_lse_.copy_(new_scale) + + @jit_fuser def get_cu_seqlens_on_cp_rank( cu_seqlens: torch.Tensor, @@ -1646,46 +1687,59 @@ def get_cu_seqlens_on_cp_rank( @jit_fuser -def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): +def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device): """ Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. - To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks - before or after CP communications (e.g., all-gather, all-to-all). This function is to compute - sequence chunk ids for reordering. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to + be contigupus before attention compute. This function is to compute sequence chunk ids for + reordering. 
""" chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) - if to_contiguous: - for rank in range(cp_size): - chunk_ids[rank] = 2 * rank - chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 - else: - for rank in range(cp_size): - chunk_ids[2 * rank] = rank - chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 return chunk_ids @jit_fuser -def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): - """Reorder sequence chunk for A2A communication.""" - if before_attn: - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] - x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) - # reorder the sequence chunks - x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) - else: - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.movedim(seq_dim, 0).contiguous() - # reorder the sequence chunks - x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] - x = x.view(cp_size, 2, *x.shape[1:]) +def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + We need to reorder sequence chunks back to discontiguous after attention compute. This function + is to compute sequence chunk ids for reordering. 
+ """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +@jit_fuser +def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): + """Reorder sequence chunk for A2A communication before attention compute.""" + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + return x + + +@jit_fuser +def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): + """Reorder sequence chunk for A2A communication after attention compute.""" + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -1713,8 +1767,8 @@ def flash_attn_a2a_communicate( a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] # reorder the sequence chunks - x = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + x = reorder_seq_chunks_for_a2a_before_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size ) # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] @@ -1740,8 +1794,8 @@ def flash_attn_a2a_communicate( # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] x = 
x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks - a2a_inputs[i] = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size ) if i > 1: with torch.cuda.stream(cp_stream): @@ -1800,6 +1854,25 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer +_cu_seqlens_info_with_cp_cache = {} + + +def _get_cu_seqlens_info_with_cp( + batch_size: int, + max_seqlen: int, + cp_size: int, + cu_seqlens: torch.Tensor, +): + """Cumulative sequence lengths with CP being considered.""" + global _cu_seqlens_info_with_cp_cache + if (batch_size, max_seqlen, cp_size) not in _cu_seqlens_info_with_cp_cache: + _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] = ( + cu_seqlens // cp_size, + cu_seqlens // (cp_size * 2), + ) + return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. 
Exchange KV between CP ranks @@ -1839,6 +1912,7 @@ def forward( cp_global_ranks, cp_stream, quantizers, + pad_between_seqs, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") @@ -1871,27 +1945,28 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type + batch_dim = None seq_dim = None + cu_seqlens_q_half, cu_seqlens_kv_half = None, None if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None + if use_fused_attention: + batch_dim = qkv_format.index("b") + cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q + ) + cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv + ) else: qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size - pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal( - cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1] - ) - pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal( - cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1] - ) max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size - cu_seqlens_q_padded = ( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size - ) - cu_seqlens_kv_padded = ( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size - ) cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] @@ -1948,7 +2023,7 @@ def forward( fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) + 
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True @@ -2048,7 +2123,6 @@ def forward( p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] - softmax_lse_ = None out = None for i in range(cp_size + 1): if i < cp_size: @@ -2076,18 +2150,19 @@ def forward( kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data if causal: if i == 0: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - elif use_fused_attention or qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True ) - elif use_fused_attention or qkv_format == "thd": + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -2202,13 +2277,10 @@ def forward( if not _use_flash_attn_3: rng_states[i] = fa_outputs[3] elif i <= rank: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - elif use_fused_attention or qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, @@ -2217,8 +2289,12 @@ def forward( True, False, ) - elif use_fused_attention or qkv_format == "thd": + elif qkv_format == "thd": + 
cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -2338,13 +2414,10 @@ def forward( if not _use_flash_attn_3: rng_states[i] = fa_outputs[3] else: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True ) - elif use_fused_attention or qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, @@ -2353,8 +2426,12 @@ def forward( True, True, ) - elif use_fused_attention or qkv_format == "thd": + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q_half + cu_seqlens_kv_per_step[i] = cu_seqlens_kv if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...] 
@@ -2483,13 +2560,10 @@ def forward( if not _use_flash_attn_3: rng_states[i] = fa_outputs[3] else: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - elif use_fused_attention or qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, @@ -2498,8 +2572,12 @@ def forward( True, True, ) - elif use_fused_attention or qkv_format == "thd": + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size @@ -2615,13 +2693,9 @@ def forward( if fp8: out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) if i == 1: - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) - if causal and qkv_format != "thd": - # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = softmax_lse.view( - *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 - ) + if qkv_format == "thd": + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -2635,8 +2709,9 @@ def forward( softmax_lse_in_packed_format, ) else: - flash_attn_fwd_softmax_lse_correction( - softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1] + flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), + softmax_lse_per_step[i - 1], ) if i < cp_size: @@ -2652,14 +2727,22 @@ def forward( for i in range(cp_size): if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: - 
flash_attn_fwd_out_correction( - out.view(*out_per_step[i].shape), - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - 0 if softmax_lse_in_packed_format else 2, - 2 if softmax_lse_in_packed_format else seq_dim, - ) + if i == 0: + out = flash_attn_fwd_out_correction_init( + out_per_step[0], + softmax_lse, + softmax_lse_per_step[0], + seq_dim, + ) + out = out.view(q.shape) + else: + flash_attn_fwd_out_correction( + out.view(*out_per_step[i].shape), + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + seq_dim, + ) elif qkv_format == "thd": tex.thd_out_correction( out, @@ -2672,14 +2755,12 @@ def forward( ) else: if qkv_format in ["bshd", "sbhd"]: - out_ = out.select(seq_dim, 1) - flash_attn_fwd_out_correction( - out_, + flash_attn_fwd_second_half_out_correction( + out, out_per_step[i], - softmax_lse_[..., 1, :], + softmax_lse, softmax_lse_per_step[i], - 0 if softmax_lse_in_packed_format else 2, - 2 if softmax_lse_in_packed_format else seq_dim, + seq_dim, ) elif qkv_format == "thd": tex.thd_out_correction( @@ -2701,7 +2782,7 @@ def forward( ctx.batch_size = out.shape[1] if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) @@ -2842,9 +2923,7 @@ def backward(ctx, dout): ) else: # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = softmax_lse.view( - *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 - ) + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: @@ -2932,7 +3011,9 @@ def backward(ctx, dout): if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) - chunk_ids_for_a2a = 
get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( + cp_size_a2a, out.device + ) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, @@ -3642,7 +3723,7 @@ def backward(ctx, dout): dk, dv = dkv[0], dkv[1] if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, @@ -3692,6 +3773,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3806,9 +3888,10 @@ def forward( max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention or qkv_format == "thd": cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_q_padded = ( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size) - ) + if cu_seqlens_q_padded is not None and qkv_format == "thd": + cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + else: + cu_seqlens_q_padded = None # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) @@ -3822,7 +3905,7 @@ def forward( # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] @@ -4011,7 +4094,7 @@ def backward(ctx, dout): # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, 
v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] @@ -4147,7 +4230,7 @@ def backward(ctx, dout): # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] @@ -4312,7 +4395,7 @@ def forward( fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) @@ -4383,7 +4466,7 @@ def forward( rng_state = fa_outputs[3] if not _use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) @@ -4534,7 +4617,7 @@ def backward(ctx, dout): out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) + chunk_ids_for_a2a = 
get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) @@ -4657,7 +4740,7 @@ def backward(ctx, dout): **fa_backward_kwargs, ) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False ) @@ -4737,6 +4820,7 @@ def attn_forward_func_with_cp( fp8=False, fp8_meta=None, quantizers=None, + pad_between_seqs=False, ) -> torch.Tensor: """ Attention implementation with context parallelism. @@ -4804,7 +4888,7 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers] + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) @@ -5823,6 +5907,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, + pad_between_seqs=False, ) else: @@ -6529,6 +6614,7 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + pad_between_seqs: bool = False, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -6667,6 +6753,7 @@ def forward( fp8=fp8, fp8_meta=fp8_meta, quantizers=quantizers, + pad_between_seqs=pad_between_seqs, ) else: with self.attention_dropout_ctx(): @@ -7083,6 +7170,7 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, + pad_between_seqs: Optional[bool] = None, ) -> torch.Tensor: """ Dot Product Attention Layer. 
@@ -7252,6 +7340,9 @@ def forward( Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. + pad_between_seqs: Optional[bool], default = `None` + If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If true, there are padding tokens between individual sequences in a packed batch. """ with self.prepare_forward( @@ -7526,13 +7617,17 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" - pad_between_seqs = ( - cu_seqlens_q_padded is not None - and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) - ) or ( - cu_seqlens_kv_padded is not None - and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) - ) + if pad_between_seqs is None: + if qkv_format == "thd": + pad_between_seqs = ( + cu_seqlens_q_padded is not None + and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) + ) or ( + cu_seqlens_kv_padded is not None + and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) + ) + else: + pad_between_seqs = False attention_params = AttentionParams( qkv_type=type(query_layer), @@ -7666,6 +7761,7 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + pad_between_seqs=pad_between_seqs, ) return self.fused_attention( query_layer, @@ -7692,6 +7788,7 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, ) from .cpu_offload import CPUOffloadEnabled @@ -8188,6 +8285,7 @@ def forward( max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, + pad_between_seqs: Optional[bool] = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """ Forward propagation for MultiheadAttention 
layer. @@ -8266,6 +8364,9 @@ def forward( Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. + pad_between_seqs: Optional[bool], default = `None` + If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If true, there are padding tokens between individual sequences in a packed batch. """ # hidden_states: [sq, b, h] @@ -8523,6 +8624,7 @@ def forward( alibi_slopes=alibi_slopes, fast_zero_fill=fast_zero_fill, inference_params=inference_params, + pad_between_seqs=pad_between_seqs, ) # =================== diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 97b1361163..fbc787d6d2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -546,6 +546,7 @@ def forward( max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, + pad_between_seqs: Optional[bool] = None, ) -> torch.Tensor: """ Transformer Layer: attention block and a feedforward network (MLP) @@ -637,6 +638,9 @@ def forward( inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference. + pad_between_seqs: Optional[bool], default = `None` + If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If true, there are padding tokens between individual sequences in a packed batch. 
""" if self_attn_mask_type is None: @@ -697,6 +701,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, + pad_between_seqs=pad_between_seqs, ) if self.apply_residual_connection_post_layernorm and not self.output_layernorm: From 77fa1e5967f34dc62f148ee8a43c642603af7389 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Fri, 7 Mar 2025 23:58:03 -0800 Subject: [PATCH 191/239] [PyTorch] Enabling Per-Tensor Current Scaling Recipe (#1471) * check in per-tensor current scaling full recipe Signed-off-by: zhongboz [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: zhongboz setup basics of current scaling quantizer in python level Signed-off-by: zhongboz [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: zhongboz add test case for current scaling dequantize Signed-off-by: zhongboz [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: zhongboz finish linear layer fwd bwd test, determined error with bf16 Signed-off-by: zhongboz [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: zhongboz achieved zero tolerance for Linear by specify gemm use_split_accumulator config Signed-off-by: zhongboz enable layernormlinear with current scaling, pass bitwise test Signed-off-by: zhongboz refactor test case code Signed-off-by: zhongboz make current scaling quantizers distrbuted, pass distributed linear&layernormlinear tests Signed-off-by: zhongboz bug fix: use cached fp8 recipe in backward Signed-off-by: zhongboz fix layernorm_mlp with current scaling, fix activation_helper with current scaling Signed-off-by: zhongboz support detailed numerical settings from recipe to quantization kernel Signed-off-by: zhongboz resolving MR comments 
Signed-off-by: zhongboz recipe naming Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * resolve mr comments, remove IS_CURRENT_SCALING template from kernels Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * resolve mr comments, make current scaling c++ test cases Signed-off-by: zhongboz * add current scaling to test_numerics.py, skip act recomp and grouped linear Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add benchmark for quantizer Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add benchmarks for linear layer Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bug fix, typo Signed-off-by: zhongboz * resolve more mr comments Signed-off-by: zhongboz * avoid potential race condition by not using from_blob to construct amax tensor in C++ Signed-off-by: zhongboz * resolve more comments Signed-off-by: zhongboz * Debug linter warnings and license check Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug import error in FP8 tensor test Signed-off-by: Tim Moon * Debug compilation error with CUDA 12.1 for Turing Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * resolve mr comments, fix activation cast fusion Signed-off-by: zhongboz * resolve comments, add NVTEQuantizationParams for compute scale Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove is_current_scaling check totally from common folder Signed-off-by: zhongboz * remove 
benchmarks, will contribute in another repo Signed-off-by: zhongboz * adjust cs default recipe config Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adjust comments in test Signed-off-by: zhongboz * Remove current scaling mode from core lib Signed-off-by: Tim Moon * Refactor current-scaling-specific logic in core C++ lib Move amax and scale update functions out of casting functions, and put into dedicated current-scaling source file. Add general API for accessing quantization config object. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add missing header in C++ tests Signed-off-by: Tim Moon * Disable test config with FP8 transpose on Blackwell Signed-off-by: Tim Moon * Fix compilation error in C++ test Signed-off-by: Tim Moon --------- Signed-off-by: zhongboz Signed-off-by: Tim Moon Co-authored-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon --- .gitignore | 1 + qa/L0_cppunittest/test.sh | 0 tests/cpp/operator/CMakeLists.txt | 2 + tests/cpp/operator/test_cast.cu | 4 + .../cpp/operator/test_cast_current_scaling.cu | 214 +++++ tests/cpp/operator/test_cast_transpose.cu | 4 + .../test_cast_transpose_current_scaling.cu | 210 +++++ tests/cpp/test_common.cu | 28 +- tests/cpp/test_common.h | 4 + tests/pytorch/distributed/run_numerics.py | 125 ++- tests/pytorch/distributed/test_numerics.py | 4 +- tests/pytorch/references/ref_per_tensor_cs.py | 105 +++ .../test_float8_current_scaling_exact.py | 802 ++++++++++++++++++ tests/pytorch/test_float8tensor.py | 120 ++- tests/pytorch/test_numerics.py | 7 + tests/pytorch/test_recipe.py | 1 + transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/common.h | 49 +- .../include/transformer_engine/recipe.h | 23 + 
.../transformer_engine/transformer_engine.h | 112 ++- transformer_engine/common/recipe/__init__.py | 94 ++ .../common/recipe/current_scaling.cu | 237 ++++++ .../common/transformer_engine.cpp | 80 +- .../common/transpose/cast_transpose.cu | 5 +- .../common/util/cast_kernels.cuh | 26 +- transformer_engine/common/utils.cuh | 2 +- transformer_engine/pytorch/constants.py | 10 + transformer_engine/pytorch/csrc/common.cpp | 11 +- transformer_engine/pytorch/csrc/common.h | 26 + .../pytorch/csrc/extensions/activation.cpp | 31 +- .../pytorch/csrc/extensions/cast.cpp | 23 + .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/csrc/extensions/quantizer.cpp | 117 +++ .../pytorch/csrc/extensions/swizzle.cpp | 2 +- .../csrc/extensions/type_converters.cpp | 3 +- transformer_engine/pytorch/csrc/pybind.h | 15 +- transformer_engine/pytorch/distributed.py | 18 +- transformer_engine/pytorch/fp8.py | 70 +- transformer_engine/pytorch/module/base.py | 14 +- .../pytorch/module/grouped_linear.py | 3 + .../pytorch/module/layernorm_linear.py | 99 ++- .../pytorch/module/layernorm_mlp.py | 206 ++++- transformer_engine/pytorch/module/linear.py | 102 ++- .../pytorch/tensor/float8_tensor.py | 171 +++- 44 files changed, 3056 insertions(+), 128 deletions(-) mode change 100644 => 100755 qa/L0_cppunittest/test.sh create mode 100644 tests/cpp/operator/test_cast_current_scaling.cu create mode 100644 tests/cpp/operator/test_cast_transpose_current_scaling.cu create mode 100644 tests/pytorch/references/ref_per_tensor_cs.py create mode 100644 tests/pytorch/test_float8_current_scaling_exact.py create mode 100644 transformer_engine/common/recipe/current_scaling.cu diff --git a/.gitignore b/.gitignore index f491b21f43..850b352d31 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ downloads/ .pytest_cache/ compile_commands.json .nfs +tensor_dumps/ \ No newline at end of file diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh old mode 100644 new mode 100755 diff --git 
a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index ce78fcaae2..6785dbf6f4 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -4,6 +4,7 @@ add_executable(test_operator test_cast.cu + test_cast_current_scaling.cu test_cast_dbias.cu test_cast_dbias_dgelu.cu test_cast_gated_swiglu.cu @@ -13,6 +14,7 @@ add_executable(test_operator test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu + test_cast_transpose_current_scaling.cu test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu index f57d1f035d..81c975b0a8 100644 --- a/tests/cpp/operator/test_cast.cu +++ b/tests/cpp/operator/test_cast.cu @@ -35,6 +35,8 @@ void compute_ref(const InputType *data, OutputType *output_c, *amax = current_max; } + +// delayed tensor scaling test template void performTest(const std::vector& shape) { using namespace test; @@ -55,6 +57,7 @@ void performTest(const std::vector& shape) { nvte_quantize(input.data(), output_c.data(), 0); float ref_amax; + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), full_size, &ref_amax, output_c.scale()); @@ -105,6 +108,7 @@ TEST_P(CastTestSuite, TestCast) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // delayed tensor scaling performTest(size); ); ); diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu new file mode 100644 index 0000000000..18325d6daf --- /dev/null +++ b/tests/cpp/operator/test_cast_current_scaling.cu @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, + const size_t size, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i] = OutputType(scale * current); + } +} + + +template +void compute_amax_scale_ref(const InputType *data, + const size_t size, + float *amax_ptr, float *scale_ptr, float* scale_inv_ptr, + float max_fp8, float epsilon) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + } + *amax_ptr = current_max; + + // compute scale from amax + float clamp_amax = current_max; + if (current_max <= epsilon){ + clamp_amax = epsilon; + } + + float scale = 1.f; + float scale_inv = 1.f; + + if (isinf(clamp_amax) || clamp_amax == 0.f) { + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; + return; + } + + // use ieee_div in CPU + scale = max_fp8 / clamp_amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. 
+ if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + scale_inv = 1.0f / scale; + + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; +} + +// current tensor scaling test +template +void performTest(const std::vector& shape) { + using namespace test; + + const size_t full_size = product(shape); + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + bool is_out_fp8 = isFp8Type(otype); + + // find out max fp8 value + float max_fp8; + if (is_out_fp8){ + switch (otype) { + case DType::kFloat8E5M2: { + max_fp8 = Quantized_Limits::max(); + } break; + case DType::kFloat8E4M3: { + max_fp8 = Quantized_Limits::max(); + } break; + default: + NVTE_ERROR("Invalid type."); + } + } + + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype, true, false); + + std::unique_ptr ref_output_c = std::make_unique(full_size); + + fillUniform(&input); + + // compute amax + float amax_to_check = 0.0f; + if (is_out_fp8){ + nvte_compute_amax(input.data(), output_c.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(output_c.data(), config, 0); + // avoid atomic amax update in cuda cast kernels because of current per-tensor scaling + amax_to_check = output_c.amax(); + output_c.set_tensor_amax_nullptr(); + } + nvte_quantize(input.data(), output_c.data(), 0); + + float ref_amax; + float ref_scale; + float ref_scale_inv; + if (is_out_fp8){ + compute_amax_scale_ref(input.rowwise_cpu_dptr(), + full_size, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f); + } + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + full_size, nullptr, is_out_fp8 ? 
output_c.scale() : 1.0f ); + + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32); + compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32); + compareResults("scale", output_c.scale(), ref_scale, 0.0f, rtol_fp32); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, 0.0f, rtol); +} + +std::vector> test_cases = { + {16}, + {16000}, + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; +} // namespace + +class CastCSTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastCSTestSuite, TestCastCS) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // current tensor scaling + performTest(size); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastCSTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_transpose.cu 
b/tests/cpp/operator/test_cast_transpose.cu index 830682eec3..380ae96190 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -38,6 +38,8 @@ void compute_ref(const InputType *data, OutputType *output_c, OutputType *output *amax = current_max; } + +// delayed tensor scaling test template void performTest(const size_t N, const size_t H) { using namespace test; @@ -75,6 +77,7 @@ void performTest(const size_t N, const size_t H) { compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } + std::vector> test_cases = {{2048, 12288}, {768, 1024}, {256, 65536}, @@ -101,6 +104,7 @@ TEST_P(CTTestSuite, TestCastTranspose) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // delayed tensor scaling performTest(size.first, size.second); ); ); diff --git a/tests/cpp/operator/test_cast_transpose_current_scaling.cu b/tests/cpp/operator/test_cast_transpose_current_scaling.cu new file mode 100644 index 0000000000..267970b34f --- /dev/null +++ b/tests/cpp/operator/test_cast_transpose_current_scaling.cu @@ -0,0 +1,210 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t, + const size_t N, const size_t H, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i * H + j] = OutputType(scale * current); + output_t[j * N + i] = OutputType(scale * current); + } + } +} + +template +void compute_amax_scale_ref(const InputType *data, + const size_t N, const size_t H, + float *amax_ptr, float *scale_ptr, float* scale_inv_ptr, + float max_fp8, float epsilon) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + current_max = fmaxf(current_max, fabsf(current)); + } + } + *amax_ptr = current_max; + + // compute scale from amax + float clamp_amax = current_max; + if (current_max <= epsilon){ + clamp_amax = epsilon; + } + + float scale = 1.f; + float scale_inv = 1.f; + + if (isinf(clamp_amax) || clamp_amax == 0.f) { + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; + return; + } + + // use ieee_div in CPU + scale = max_fp8 / clamp_amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. 
+ if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + scale_inv = 1.0f / scale; + + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; +} + +// current tensor scaling test +template +void performTest(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + bool is_out_fp8 = isFp8Type(otype); + + // find out max fp8 value + float max_fp8; + if (is_out_fp8){ + switch (otype) { + case DType::kFloat8E5M2: { + max_fp8 = Quantized_Limits::max(); + } break; + case DType::kFloat8E4M3: { + max_fp8 = Quantized_Limits::max(); + } break; + default: + NVTE_ERROR("Invalid type."); + } + } + + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); + + std::unique_ptr ref_output_c = std::make_unique(N * H); + std::unique_ptr ref_output_t = std::make_unique(N * H); + + fillUniform(&input); + + // compute amax + float amax_to_check = 0.0f; + if (is_out_fp8){ + nvte_compute_amax(input.data(), output.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(output.data(), config, 0); + // avoid atomic amax update in cuda cast kernels because of current per-tensor scaling + amax_to_check = output.amax(); + output.set_tensor_amax_nullptr(); + } + nvte_quantize(input.data(), output.data(), 0); + + float ref_amax; + float ref_scale; + float ref_scale_inv; + if (is_out_fp8){ + compute_amax_scale_ref(input.rowwise_cpu_dptr(), + N, H, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f); + } + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + ref_output_t.get(), N, H, nullptr, + is_out_fp8 ? 
output.scale() : 1.0f ); + + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32); + compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32); + compareResults("scale", output.scale(), ref_scale, 0.0f, rtol_fp32); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32); + compareResults("scale_inv_columnwise", output.columnwise_cpu_scale_inv_ptr()[0], ref_scale_inv, 0.0f, rtol_fp32); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output, ref_output_c.get(), true, 0.0f, rtol); + compareResults("output_t", output, ref_output_t.get(), false, 0.0f, rtol); +} + +std::vector> test_cases = {{2048, 12288}, + {768, 1024}, + {256, 65536}, + {65536, 128}, + {256, 256}, + {120, 2080}, + {8, 8}, + {1, 3221}, // Prime 456 + {2333, 1}, // Prime 345 + {1481, 677}}; // Primes 234, 123 +} // namespace + +class CTCSTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CTCSTestSuite, TestCastTransposeCS) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // current tensor scaling + performTest(size.first, size.second); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CTCSTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::ValuesIn(test::all_fp_types), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + 
std::to_string(std::get<2>(info.param).first) + "X" + + std::to_string(std::get<2>(info.param).second); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ec4a9bdbb7..24aff83d8a 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -103,10 +103,6 @@ size_t DIVUP(const size_t &x, const size_t &y){ return (((x) + ((y)-1)) / (y)); } -inline bool is_tensor_scaling(const NVTEScalingMode &mode) { - return mode == NVTE_DELAYED_TENSOR_SCALING; -} - struct scale_inv_meta { std::vector shape; DType type; @@ -233,7 +229,7 @@ Tensor::Tensor(const std::string& name, tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); if (isFp8Type(type)) { - if (is_tensor_scaling(scaling_mode)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) @@ -296,11 +292,13 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (isFp8Type(dtype())) { - if (is_tensor_scaling(tensor_.scaling_mode())) { - cudaMemcpy(amax_cpu_data_.get(), - tensor_.amax(), - sizeof(float), - cudaMemcpyDeviceToHost); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.amax() != nullptr){ + cudaMemcpy(amax_cpu_data_.get(), + tensor_.amax(), + sizeof(float), + cudaMemcpyDeviceToHost); + } cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float), @@ -336,9 +334,11 @@ void Tensor::from_cpu() const { cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } if (isFp8Type(dtype())) { - if (is_tensor_scaling(tensor_.scaling_mode())) { - cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.amax() != nullptr){ + cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + } cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), 
sizeof(float), cudaMemcpyHostToDevice); } @@ -361,7 +361,7 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (is_tensor_scaling(tensor_.scaling_mode())) { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; from_cpu(); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index dc515ccb8e..4352056ddb 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -256,6 +256,10 @@ class Tensor { return columnwise_; } + void set_tensor_amax_nullptr(){ + tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 2d301e3151..e2e78b72b1 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -14,13 +14,15 @@ import torch from torch import nn import torch.distributed as dist - +import transformer_engine_torch as tex from transformer_engine.common.recipe import ( MXFP8BlockScaling, DelayedScaling, + Float8CurrentScaling, Format, Recipe, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -45,6 +47,8 @@ def quantization_recipe() -> Recipe: ) if QUANTIZATION == "mxfp8": return MXFP8BlockScaling() + if QUANTIZATION == "fp8_cs": + return Float8CurrentScaling() return te.fp8.get_default_fp8_recipe() @@ -88,6 +92,7 @@ def main(argv=None, namespace=None): HIDDEN_SIZE = 128 test_dict = [ + test_quantizer, test_linear, test_layernorm, test_layernorm_linear, @@ -152,7 +157,12 @@ def dist_print(msg, src=None, end="\n", error=False): def _get_tolerances(dtype): - if QUANTIZATION is not None: + # loose tolerances for fp8_cs because of sequence parallel & amax reduction + # so that each rank 
###############################################
#                  Quantizer                  #
###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
    """Build the (reference, distributed) quantizer pair for one test case.

    The first returned quantizer runs without amax reduction and serves as the
    single-GPU reference; the second reduces amax over ``tp_group`` and is the
    object under test.

    Raises:
        ValueError: if ``quantizer_class`` is not ``Float8CurrentScalingQuantizer``.
    """
    if quantizer_class != Float8CurrentScalingQuantizer:
        raise ValueError(f"Unsupported quantizer class: {quantizer_class}")

    distributed = quantizer_class(
        fp8_dtype=fp8_dtype,
        device=device,
        with_amax_reduction=True,
        amax_reduction_group=tp_group,
        amax_reduction_size=tp_size,
    )
    reference = quantizer_class(
        fp8_dtype=fp8_dtype,
        device=device,
        with_amax_reduction=False,
    )
    return reference, distributed


def _shard_tensor(x, world_size, axis):
    """Split ``x`` along ``axis`` into chunks of ``x.size(axis) // world_size``.

    Each chunk is returned as a detached CUDA copy that keeps the
    ``requires_grad`` flag of ``x``.
    """
    chunk = x.size()[axis] // world_size
    return [
        piece.detach().clone().requires_grad_(x.requires_grad).cuda()
        for piece in torch.split(x, chunk, axis)
    ]


@run_distributed_test()
def _test_quantizer(input_dtype, fp8_dtype):
    """Check that distributed quantization reproduces the single-GPU scale_inv.

    Args:
        input_dtype (torch.dtype): dtype of the high-precision input.
        fp8_dtype (tex.DType): FP8 format to quantize to.
    """
    rows, cols = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE
    is_root = WORLD_RANK == 0

    # High-precision input, generated on the host.
    full_cpu = torch.randn((rows, cols), device="cpu").to(input_dtype)
    # Plant one very large element in the last row, i.e. outside rank 0's
    # shard, so the test only passes if the amax reduction really happens.
    full_cpu[rows - 1, cols - 1] = 1e4

    # Rank 0 also keeps the whole tensor for the reference quantization.
    if is_root:
        full_on_root = full_cpu.clone().detach().requires_grad_(True).to("cuda")
    local_shard = _shard_tensor(full_cpu, WORLD_SIZE, 0)[WORLD_RANK]

    reference_q, distributed_q = _construct_quantizer(
        Float8CurrentScalingQuantizer, fp8_dtype, local_shard.device, NCCL_WORLD, WORLD_SIZE
    )

    # Single-GPU reference result (rank 0 only).
    if is_root:
        fp8_reference = reference_q(full_on_root)

    # Every rank quantizes its own shard with the reducing quantizer.
    fp8_local = distributed_q(local_shard)

    # scale_inv must match exactly (zero tolerance).
    if is_root:
        torch.testing.assert_close(
            fp8_reference._scale_inv, fp8_local._scale_inv, rtol=0.0, atol=0.0
        )


def test_quantizer():
    """Run the quantizer test over every (input dtype, FP8 dtype) combination.

    Only the ``fp8_cs`` scheme is exercised, because it is the one whose
    quantizer performs an amax reduction.
    """
    # Nothing to check for the other quantization schemes.
    if QUANTIZATION != "fp8_cs":
        return

    source_dtypes = (torch.float32, torch.bfloat16)
    target_dtypes = (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2)

    for src in source_dtypes:
        for dst in target_dtypes:
            _test_quantizer(src, dst)
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
def test_distributed(quantization):
    """Run the distributed numerics test for one quantization scheme.

    The test is skipped when the hardware lacks the required support: both
    "fp8" and "fp8_cs" need FP8, and "mxfp8" needs MXFP8.
    """
    # "fp8_cs" has the same hardware requirement as "fp8", so it reports the
    # same skip reason. The previous code passed the boolean `fp8_available`
    # as the reason string.
    if quantization in ("fp8", "fp8_cs") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    _run_test(quantization)
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
    """Reference computation of amax, scale and scale_inv for current scaling.

    Args:
        x: input tensor; converted to FP32 before the amax is taken.
        quant_dtype: torch dtype of the quantized format; its ``finfo.max`` is
            the numerator of the scale.
        eps: lower bound applied to amax before the division.
        pow_2_scales: if true, round the scale down to a power of two.

    Returns:
        ``(scale, scale_inv, amax)``, each a one-element FP32 tensor. ``amax``
        is the clamped value.
    """
    x_fp32 = x.to(torch.float32)
    amax = torch.amax(torch.abs(x_fp32)).view(1)
    assert amax.dtype == torch.float, "amax must be a float tensor."
    fp8_max = torch.finfo(quant_dtype).max
    # Clamp amax from below to avoid division by small numbers. clamp() keeps
    # the result on amax's device without materialising a separate tensor for
    # eps on the default device.
    amax = torch.clamp(amax, min=eps)

    # Compute scale factor
    scale = torch.div(fp8_max, amax)
    # Note frexp doesn't give back inf for exponent with an inf input
    # We take care of inf before pow_2_scales: an infinite scale is replaced by
    # the largest finite FP32 value. (The original followed this with a second
    # "inf -> 1.0" replacement, which could never trigger after this line.)
    scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
    if pow_2_scales:
        # Calculate rounded down exponent
        _, exp = torch.frexp(scale)
        # Positive numbers are always returned as mant, exp with
        # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
        # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
        # of the shift. Subnormal and zero cases need not be considered because
        # the smallest possible result of fp8_max / amax is still normal.
        exp = exp - 1
        # No subnormals and zero.
        assert (exp > -127).all()
        # TODO: If/when adding a URM option an option is to cap to 126
        # rather than allowing the full range of FP32 (2 - 2^23) x 2^127
        # addresses cases where adding a mantissa overflows into inf scales.
        # Not necessary currently without additional scale smudging options.
        unity = torch.tensor([1.0], device=exp.device)
        torch.ldexp(unity, exp, out=scale)
        # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
        # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
        # calculation.
        scale = torch.where(amax == float("inf"), 0.0, scale)

    # Handle overflow cases for amax zero causing NaN
    scale = torch.where(amax == 0, 1.0, scale)
    # Compute scale_inv
    scale_inv = torch.reciprocal(scale)

    return scale, scale_inv, amax


def _multi_dim_transpose(tensor):
    """Return a contiguous copy of ``tensor`` with its last dim moved to the front.

    Tensors with zero or one dimension are returned unchanged (same object).
    """
    dims = list(range(len(tensor.shape)))

    if len(dims) <= 1:
        return tensor

    # Circular shift of the dimension order: (d0, ..., dn) -> (dn, d0, ..., dn-1)
    new_order = [dims[-1], *dims[:-1]]

    return tensor.permute(new_order).contiguous()


# current scaling reference quantization
def ref_per_tensor_cs_cast(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]":
    """Reference per-tensor current-scaling cast.

    Args:
        tensor: high-precision input.
        fp8_dtype: target FP8 format, mapped to a torch dtype through
            ``TE_DType_To_Torch``.
        return_transpose: also return the transposed quantized data.
        force_pow_2_scales: round the scale down to a power of two.
        amax_epsilon: lower bound applied to amax.

    Returns:
        ``(qx, sx, qx_t, sx_t)``: the quantized data, its scale_inv, and the
        transposed data with its scale_inv. The last two are ``None`` unless
        ``return_transpose`` is set.
    """
    quant_dtype_torch = TE_DType_To_Torch[fp8_dtype]
    scale, scale_inv, _ = _ref_compute_amax_scale(
        tensor,
        quant_dtype_torch,
        amax_epsilon,
        force_pow_2_scales,
    )

    qx = (tensor.float() * scale).to(quant_dtype_torch)
    sx = scale_inv
    qx_t = None
    sx_t = None

    # `scale` has shape (1,), so a 0-dim input broadcasts to shape (1,);
    # restore the scalar shape.
    if tensor.shape == torch.Size([]):
        qx = qx.view([])

    if return_transpose:
        qx_t = _multi_dim_transpose(qx)
        sx_t = sx
    return qx, sx, qx_t, sx_t
class GetRecipes:
    """Factories for the recipes compared by the tests in this file."""

    @staticmethod
    def none():
        """Return ``None``, meaning "run without FP8"."""
        return None

    @staticmethod
    def fp8_per_tensor_current_scaling_default():
        """Return a ``Float8CurrentScaling`` recipe with default settings."""
        return Float8CurrentScaling()


# base class for validating current_scaling x linear layer
class TestFP8RecipeLinearBase:
    """Helpers for running a linear layer under two recipes and comparing them.

    Shape convention used throughout: ``x`` is [M, K], ``w`` is [N, K] and
    ``gradient`` is [M, N].
    """

    @staticmethod
    def _prepare_data(
        batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
    ):
        """Create seeded random CUDA tensors ``(x, w, bias, gradient)``.

        ``bias`` is ``None`` when ``use_bias`` is false.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
        w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
        bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
        gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")

        return x, w, bias, gradient

    @staticmethod
    def _shard_tensor(x, world_size, axis):
        """Split ``x`` along ``axis`` into chunks of ``x.size(axis) // world_size``.

        Each chunk is a detached copy that keeps the ``requires_grad`` flag of ``x``.
        """
        split_size = x.size()[axis] // world_size
        return [
            chunk.detach().clone().requires_grad_(x.requires_grad)
            for chunk in torch.split(x, split_size, axis)
        ]

    @staticmethod
    def _gather_tensor(local, world_size, tp_group, concat_dim):
        """All-gather ``local`` over ``tp_group`` and concatenate along ``concat_dim``."""
        out_list = [torch.zeros_like(local) for _ in range(world_size)]
        torch.distributed.all_gather(out_list, local, tp_group)
        return torch.cat(out_list, dim=concat_dim)

    @staticmethod
    def _all_reduce_tensor(local, world_size, tp_group):
        """All-reduce ``local`` in place over ``tp_group`` and return it.

        With ``world_size == 1`` the tensor is returned untouched.
        """
        if world_size == 1:
            return local
        torch.distributed.all_reduce(local, group=tp_group, async_op=False)
        return local

    @staticmethod
    def _get_sum_abs_error(a, b):
        """Return ``sum(|a - b|)``."""
        return torch.sum(torch.abs(a - b))

    @staticmethod
    def _get_mean_abs_relative_error(a, b):
        """Return ``mean(|(a - b) / b|)``."""
        return torch.mean(torch.abs((a - b) / b))

    @staticmethod
    def _load_golden_tensor_values(a, b):
        """Return ``sum(|a - b|)``.

        NOTE(review): despite its name this helper loads nothing; it computes
        the same value as ``_get_sum_abs_error``.
        """
        return torch.sum(torch.abs(a - b))

    @staticmethod
    def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias):
        """Load the golden tensors for one configuration from ``dump_dir``.

        Args:
            dump_dir: directory (path object) holding the ``.pt`` dumps.
            get_recipe: callable returning the recipe the dumps were made with.
            dims: ``(batch_size, hidden_size, out_size)``.
            input_dtype: dtype of the inputs; part of the file name.
            use_bias: whether a ``bgrad`` dump is expected.

        Returns:
            A dict with keys ``"y"``, ``"dgrad"``, ``"wgrad"`` (and ``"bgrad"``
            when ``use_bias``), or ``None`` as soon as one expected file is
            missing.
        """
        recipe = get_recipe()
        batch_size, hidden_size, out_size = dims
        fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True)
        fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True)
        fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)

        # Expected tensor names based on the naming template
        scaling_type = (  # Assuming the scaling type is PER_TENSOR for this example
            "ScalingType.PER_TENSOR"
        )
        current_seed = torch.initial_seed()  # Get the current seed

        # Every file shares the same suffix; only the leading tensor key differs.
        suffix = (
            f"{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}"
            f"_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt"
        )
        expected_tensor_names = {
            key: f"{key}_{suffix}" for key in ("y", "dgrad", "wgrad", "bgrad")
        }

        if not use_bias:
            expected_tensor_names.pop("bgrad")

        # Check if all expected tensors are in the tensor dumps directory
        tensor_map = {}
        for tensor_key, tensor_name in expected_tensor_names.items():
            tensor_path = dump_dir / tensor_name
            if not os.path.exists(tensor_path):
                print(f"Missing tensor: {tensor_name}")
                return None

            # Load the tensor. NOTE(review): torch.load unpickles the file, so
            # the dump directory must only contain trusted files.
            tensor_map[tensor_key] = torch.load(tensor_path)
        return tensor_map

    @classmethod
    def run_linear_preprocess_parallel(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_size=1,
        rank=0,
    ):
        """Return the shards of ``(x, w, bias, gradient)`` owned by ``rank``.

        Nothing is split when ``tp_size <= 1``. Column parallel splits ``w``
        and ``bias`` along N and ``gradient`` along N (plus ``x`` along M with
        sequence parallelism). Row parallel splits ``x`` and ``w`` along K
        (plus ``gradient`` along M with sequence parallelism).
        """
        if tp_size > 1:
            if parallel_mode == "column":
                # split w in N dim, which should be axis 0
                w = cls._shard_tensor(w, tp_size, 0)[rank]
                bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
                # split gradient in N dim, which should be axis 1
                gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
                if sequence_parallel:
                    # split x in M dim, which should be axis 0
                    x = cls._shard_tensor(x, tp_size, 0)[rank]
            # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
            if parallel_mode == "row":
                # split x in K dim, which should be axis 1
                x = cls._shard_tensor(x, tp_size, 1)[rank]
                # split w in K dim, which should be axis 1
                w = cls._shard_tensor(w, tp_size, 1)[rank]
                if sequence_parallel:
                    # split gradient in M dim, which should be axis 0
                    gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
        return x, w, bias, gradient

    @classmethod
    def run_linear_postprocess_parallel(
        cls,
        y_q,
        dgrad,
        wgrad,
        bgrad,
        parallel_mode,
        sequence_parallel,
        tp_size,
        tp_group,
    ):
        """Gather per-rank outputs and gradients back into full tensors.

        Inverse of ``run_linear_preprocess_parallel``; a no-op when
        ``tp_size <= 1``.
        """
        if tp_size > 1:
            if parallel_mode == "column":
                # gather y_q in N dim, which should be axis 1
                y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
                # gather wgrad in N dim, which should be axis 0
                wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
                # gather bgrad in N dim, which should be axis 0
                bgrad = (
                    cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
                )
                if sequence_parallel:
                    # gather dgrad in M dim, which should be axis 0
                    dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
            if parallel_mode == "row":
                # gather dgrad in K dim, which should be axis 1
                dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
                # gather wgrad in K dim, which should be axis 1
                wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
                if sequence_parallel:
                    # gather y_q in M dim, which should be axis 0
                    y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
                    # we need to sum bias gradient when using TP + SP
                    bgrad = (
                        cls._all_reduce_tensor(bgrad, tp_size, tp_group)
                        if bgrad is not None
                        else None
                    )

        return y_q, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_one_step(
        cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
    ):
        """Run one forward/backward pass and return ``(y_q, dgrad, wgrad, bgrad)``.

        ``bgrad`` is ``None`` when the layer has no bias. With
        ``fuse_wgrad_accumulation`` the weight gradient is read from
        ``weight.main_grad`` and ``weight.grad`` must be ``None``.
        """
        # reset gradients
        layer.zero_grad()
        x.grad = None

        # Forward pass
        if isinstance(layer, te.Linear):
            # Kitchen Linear
            y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
        else:
            # the default torch.nn.Linear
            y_q = layer(x)

        # Backward pass
        y_q.backward(gradient)

        # Collect gradients
        dgrad = x.grad
        bgrad = (
            layer._parameters["bias"].grad
            if layer._parameters.get("bias", None) is not None
            else None
        )
        assert "weight" in layer._parameters
        if fuse_wgrad_accumulation:
            wgrad = layer._parameters["weight"].main_grad
            assert layer._parameters["weight"].grad is None
        else:
            wgrad = layer._parameters["weight"].grad

        return y_q, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_multiple_steps(
        cls,
        layer,
        x,
        gradient,
        run_num_steps,
        enable_weight_cache,
        fuse_wgrad_accumulation=False,
    ):
        """Run ``run_num_steps`` forward/backward passes and stack the results.

        Step ``i`` feeds ``x + i``. With ``enable_weight_cache`` only step 0 is
        flagged as the first microbatch; otherwise the flag is ``None``.

        Returns:
            ``(y_q, dgrad, wgrad, bgrad)`` where each tensor has a new leading
            dimension of size ``run_num_steps``. ``bgrad`` is ``None`` when the
            layer has no bias or no bias gradient was produced.

        NOTE(review): the stacked results carry an extra leading dimension, but
        the gather axes in ``run_linear_postprocess_parallel`` are written for
        un-stacked tensors. Confirm before combining multiple steps with
        ``tp_size > 1``.
        """

        y_q_list, dgrad_list, wgrad_list = [], [], []
        bgrad_list = [] if layer._parameters.get("bias", None) is not None else None

        for i in range(run_num_steps):
            x_i = (x + i).clone().detach().requires_grad_(True)
            # run_linear_one_step
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
                layer,
                x_i,
                gradient,
                is_first_microbatch=(i == 0) if enable_weight_cache else None,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            )

            # Collect results
            y_q_list.append(y_q.detach().clone())
            dgrad_list.append(dgrad.detach().clone())
            wgrad_list.append(wgrad.detach().clone())
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())

        # The original fell off the end here and returned None, although
        # run_linear unpacks four values from this call.
        return (
            torch.stack(y_q_list),
            torch.stack(dgrad_list),
            torch.stack(wgrad_list),
            torch.stack(bgrad_list) if bgrad_list else None,
        )

    @classmethod
    def run_linear(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
        run_num_steps=1,
        enable_weight_cache=False,
        fuse_wgrad_accumulation=False,
    ):
        """Run a ``te.Linear`` layer on the given data and return its results.

        If Model parallel, split inputs for a given rank and return the gathered
        output and gradients, so that they can be compared with the reference
        single GPU run.

        Returns:
            ``(y_q, dgrad, wgrad, bgrad)``.
        """
        # clone inputs and move to current device
        # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
        x = x.clone().detach().requires_grad_(True).to("cuda")
        w = w.clone().detach().to("cuda")
        gradient = gradient.clone().detach().to("cuda")
        bias = bias.clone().detach().to("cuda") if bias is not None else None
        # Layer sizes are taken from the full (unsharded) tensors.
        in_features = x.shape[1]
        out_features = w.shape[0]

        # If Model parallel: split inputs for a given rank
        x, w, bias, gradient = cls.run_linear_preprocess_parallel(
            x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
        )

        # set data types
        params_dtype = x.dtype

        # Create linear layer and copy weights
        layer = te.Linear(
            in_features,
            out_features,
            bias=bias is not None,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            tp_size=tp_size,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        )

        layer = layer.to("cuda")

        with torch.no_grad():
            layer.weight.copy_(w)
            if bias is not None:
                layer.bias.copy_(bias)

        if fuse_wgrad_accumulation:
            assert (
                run_num_steps > 1
            ), "Fused weight gradient accumulation requires run_num_steps > 1"
            layer.weight.main_grad = torch.zeros_like(layer.weight)

        # Run one step or multiple steps
        if run_num_steps == 1:
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
        else:
            y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
                layer,
                x,
                gradient,
                run_num_steps,
                enable_weight_cache,
                fuse_wgrad_accumulation,
            )

        # If Model parallel: gather output and gradients from all ranks
        y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
            y_q,
            dgrad,
            wgrad,
            bgrad,
            parallel_mode,
            sequence_parallel,
            tp_size,
            tp_group,
        )

        return y_q, dgrad, wgrad, bgrad

    def compare_recipe(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        use_bias,
        seed,
        dtype,
        y_error=0.0,
        dgrad_error=0.0,
        wgrad_error=0.0,
        bgrad_error=0.0,
        recipe1_golden_tensors=None,
        recipe2_golden_tensors=None,
    ):
        """Run the same linear problem under two recipes and compare the results.

        ``recipe1`` produces the reference and ``recipe2`` the values under
        test. Both are factories such as the ``GetRecipes`` static methods;
        passing ``GetRecipes.none`` runs without FP8 autocast.

        The mean absolute relative error of each result must be strictly below
        the matching ``*_error`` bound, so the default bounds of 0.0 can never
        be satisfied and callers are expected to pass real thresholds. When
        ``recipe2_golden_tensors`` is given, the ``recipe2`` results must also
        match those tensors exactly. ``recipe1_golden_tensors`` is accepted but
        not used.
        """
        x, w, bias, gradient = self._prepare_data(
            batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype
        )

        # recipe1
        using_fp8_recipe = recipe1 != GetRecipes.none
        if using_fp8_recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
        else:
            y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)

        # recipe2
        using_fp8_recipe = recipe2 != GetRecipes.none
        if using_fp8_recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
                y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
        else:
            y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)

        # Compare results (mean abs relative error)
        assert (
            self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error
        ), "y and y_ref has too large mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error
        ), "dgrad and dgrad_ref has too large mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error
        ), "wgrad and wgrad_ref has too large mean abs relative error"
        if use_bias:
            assert (
                self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error
            ), "bgrad and bgrad_ref has too large mean abs relative error"

        # enforce zero tolerance check when we can find golden tensor value dump
        if recipe2_golden_tensors is not None:
            torch.testing.assert_close(
                y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0
            )
            torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0)
            torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0)
            if use_bias:
                torch.testing.assert_close(
                    bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0
                )
TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase): + + @staticmethod + def _check_golden_tensor_dumps( + dump_dir, get_recipe, dims, input_dtype, use_bias, normalization + ): + recipe = get_recipe() + batch_size, hidden_size, out_size = dims + fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) + + # Expected tensor names based on the naming template + scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example + "ScalingType.PER_TENSOR" + ) + current_seed = torch.initial_seed() # Get the current seed + + expected_tensor_names = { + "y": f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "ln_out": f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "dgrad": f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "wgrad": f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "bgrad": f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + } + + if not use_bias: + expected_tensor_names.pop("bgrad") + + # Check if all expected tensors are in the tensor dumps directory + tensor_map = {} + for tensor_key, tensor_name in expected_tensor_names.items(): + tensor_path = dump_dir / tensor_name + if not os.path.exists(tensor_path): + print(f"Missing tensor: {tensor_name}") + return None + + # Load the tensor + tensor_map[tensor_key] = torch.load(tensor_path) + return tensor_map + + @classmethod + def run_linear_one_step(cls, layer, x, 
gradient, is_first_microbatch=None): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + + parameters = layer._parameters + + # bias and weight gradients + bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None + assert "weight" in parameters + wgrad = parameters["weight"].grad + + return y_q, ln_out, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False + ): + # raise error, no test case for multiple steps for now + raise NotImplementedError("LayerNormLinear does not support test multiple steps for now") + + @classmethod + def run_layernorm_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + LayerNormLinearClass=te.LayerNormLinear, + normalization="LayerNorm", + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. 
+ """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = LayerNormLinearClass( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + normalization=normalization, + return_layernorm_output=True, + ) + + layer = layer.to("cuda") + + # Copy weights + # kitchen_linear has different parameter names + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + # Run one step + y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, ln_out, dgrad, wgrad, bgrad + + def compare_recipe( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed, + dtype, + y_error=0.0, + ln_out_error=0.0, + dgrad_error=0.0, + wgrad_error=0.0, + bgrad_error=0.0, + normalization="LayerNorm", + LayerNormLinearClass1=te.LayerNormLinear, + LayerNormLinearClass2=te.LayerNormLinear, + recipe1_golden_tensors=None, + recipe2_golden_tensors=None, + ): + x, w, bias, gradient = self._prepare_data( + batch_size, 
hidden_size, out_size, use_bias, seed=seed, dtype=dtype + ) + + # recipe1 + using_fp8_recipe = recipe1 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass1, + ) + else: + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass1, + ) + + # recipe2 + using_fp8_recipe = recipe2 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass2, + ) + else: + y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass2, + ) + + # Compare results (mean abs relative error) + assert ( + self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error + ), "y and y_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(ln_out, ln_out_ref).item() < ln_out_error + ), "ln_out and ln_out_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error + ), "dgrad and dgrad_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error + ), "wgrad and wgrad_ref has too large mean abs relative error" + if use_bias: + assert ( + self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error + ), "bgrad and bgrad_ref has too large mean abs relative error" + + # enforce zero tolerance check when we can find golden tensor value dump + if 
recipe2_golden_tensors is not None: + torch.testing.assert_close( + y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0 + ) + torch.testing.assert_close(ln_out, recipe2_golden_tensors["ln_out"], atol=0.0, rtol=0.0) + torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0) + torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0) + if use_bias: + torch.testing.assert_close( + bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0 + ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + 
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 9d01527ac5..42600e3099 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -12,9 +12,17 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from 
transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8Tensor, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch +from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported import transformer_engine_torch as tex +from references.ref_per_tensor_cs import ref_per_tensor_cs_cast + # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] # TE FP8 dtypes @@ -42,6 +50,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# delayed scaling def to_float8( tensor: torch.Tensor, fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, @@ -56,6 +65,29 @@ def to_float8( return quantizer(tensor.cuda()) +# current scaling +def to_float8_CS( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + return_transpose: bool = False, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, +) -> Float8Tensor: + """Cast tensor to FP8""" + tensor = tensor.cuda() + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=tensor.device, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + if return_transpose: + quantizer.set_usage(rowwise=True, columnwise=True) + else: + quantizer.set_usage(rowwise=True, columnwise=False) + return quantizer(tensor) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFloat8Tensor: @@ -310,3 +342,89 @@ def test_set_data(self): assert x.size() == y.size() assert x.dtype == y.dtype assert x.device == y.device + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestCurrentScalingFloat8Tensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + 
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize( + "dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]] + ) + @pytest.mark.parametrize("return_transpose", [True, False], ids=str) + @pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str) + @pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str) + def test_quantize( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + dims: DimsType, + return_transpose: bool, + force_pow_2_scales: bool, + amax_epsilon: float, + ) -> None: + """Check numerical error when casting to FP8""" + + # Skip invalid configurations + if non_tn_fp8_gemm_supported() and return_transpose: + pytest.skip("FP8 transpose is neither needed nor supported on current system") + + # Initialize random high precision data + device = "cuda" + x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1 + + # Cast to FP8 and back + x_fp8 = to_float8_CS( + x_hp, + fp8_dtype=fp8_dtype, + return_transpose=return_transpose, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + # get reference implementation of current scaling + x_fp8_ref, sx_ref, x_fp8_t_ref, _ = ref_per_tensor_cs_cast( + x_hp, + fp8_dtype=fp8_dtype, + return_transpose=return_transpose, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + torch.testing.assert_close(x_fp8._data, x_fp8_ref.view(torch.uint8), atol=0.0, rtol=0.0) + torch.testing.assert_close(x_fp8._scale_inv, sx_ref, atol=0.0, rtol=0.0) + if return_transpose: + torch.testing.assert_close( + x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), atol=0.0, rtol=0.0 + ) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]]) + def test_quantize_dequantize( + self, fp8_dtype: tex.DType, dtype: 
torch.dtype, dims: DimsType + ) -> None: + """Check numerical error when casting to FP8 and back""" + + # Initialize random high precision data + device = "cuda" + x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1 + + # Cast to FP8 and back + x_fp8 = to_float8_CS(x_hp, fp8_dtype=fp8_dtype) + x_fp8_dequantized = x_fp8.dequantize() + + # Check results + torch.testing.assert_close(x_fp8_dequantized, x_hp, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype]) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a72ba097a1..5bec7f7c7f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -100,6 +100,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq fp8_recipes = [ recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), ] @@ -670,6 +671,8 @@ def test_gpt_full_activation_recompute( pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for full recompute.") config = model_configs[model] @@ -1482,6 +1485,8 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_mxfp8) if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches pytest.skip("MXFP8 unsupported for grouped linear.") + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: @@ -1675,6 +1680,8 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_mxfp8) if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches pytest.skip("MXFP8 unsupported for grouped linear.") + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current 
Scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index dcac5f1500..30989bec61 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -23,6 +23,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# FP8 per tensor delayed scaling @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFP8Recipe: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 0a2abb6e4e..007618ad57 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -86,6 +86,7 @@ list(APPEND transformer_engine_SOURCES fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu + recipe/current_scaling.cu recipe/delayed_scaling.cu comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 46eb248156..4163505db6 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -29,6 +29,18 @@ namespace transformer_engine { +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); } + +inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -132,7 
+144,7 @@ struct Tensor { if (!has_data() && has_columnwise_data()) { const auto &data_shape = columnwise_data.shape; if (data_shape.empty()) return 1; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (is_tensor_scaling(scaling_mode)) { return product(data_shape, 1, data_shape.size()); } else { return product(data_shape, 0, data_shape.size() - 1); @@ -152,7 +164,7 @@ struct Tensor { if (!has_data() && has_columnwise_data()) { const auto &data_shape = columnwise_data.shape; if (data_shape.empty()) return 1; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (is_tensor_scaling(scaling_mode)) { return data_shape.front(); } else { return data_shape.back(); @@ -164,6 +176,16 @@ struct Tensor { } }; +struct QuantizationConfig { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0f; + + static constexpr size_t attr_sizes[] = { + sizeof(bool), // force_pow_2_scales + sizeof(float) // amax_epsilon + }; +}; + template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -396,6 +418,15 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) 
\ + if (CONDITION) { \ + constexpr bool FLAG = true; \ + { __VA_ARGS__ } \ + } else { \ + constexpr bool FLAG = false; \ + { __VA_ARGS__ } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { @@ -449,20 +480,6 @@ bool is_fp8_dtype(const DType t); std::string to_string(const DType type); std::string to_string(const NVTEScalingMode &type); -inline bool is_tensor_scaling(const NVTEScalingMode &mode) { - return mode == NVTE_DELAYED_TENSOR_SCALING; -} - -inline bool is_block_scaling(const NVTEScalingMode &mode) { - return mode != NVTE_DELAYED_TENSOR_SCALING; -} - -inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { - return is_tensor_scaling(mode); -} - -inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } - /*! \brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index b30a6e1338..44614bbe6b 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -73,6 +73,29 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); +/*! \brief Compute an FP8 tensor's amax. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! 
\brief Update an FP8 tensor's scale based on its amax. + * + * This is only supported for FP8 tensors with per-tensor scaling. + * Options are primarily intended for FP8 current-scaling recipes. + * + * \param[in,out] output FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e393dbffc4..e91f3c4836 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -68,11 +68,14 @@ enum NVTETensorParam { }; /*! \enum NVTEScalingMode - * \brief Granularity of scaling: + * \brief Tensor data format. */ enum NVTEScalingMode { - /*! Single scale per tensor, computed in delayed manner. - Used also for high precision data, without scaling */ + /*! Either an unquantized tensor or an FP8 tensor with per-tensor scaling + * + * Not necessary used for delayed tensor scaling. The unintuitive + * name reflects legacy usage. + */ NVTE_DELAYED_TENSOR_SCALING = 0, /*! Single scale per block of 32 elements consecutive in either rowwise or columnwise direction */ @@ -266,6 +269,57 @@ void nvte_tensor_pack_create(NVTETensorPack *pack); */ void nvte_tensor_pack_destroy(NVTETensorPack *pack); +/*! \brief Configuration for tensor quantization. */ +typedef void *NVTEQuantizationConfig; + +/*! \enum NVTEQuantizationConfigAttribute + * \brief Type of option for tensor quantization. + */ +enum NVTEQuantizationConfigAttribute { + /*! Whether to force power of 2 scales */ + kNVTEQuantizationConfigForcePow2Scales = 0, + /*! 
Small value to add to amax for numerical stability */ + kNVTEQuantizationConfigAmaxEpsilon = 1, + kNVTEQuantizationConfigNumAttributes +}; + +/*! \brief Create a new quantization config. + * \return A new quantization config. + */ +NVTEQuantizationConfig nvte_create_quantization_config(); + +/*! \brief Query an option in quantization config. + * + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in quantization config. + * + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, const void *buf, + size_t size_in_bytes); + +/*! \brief Destroy a quantization config. + * + * \param[in] config Config to be destroyed. + */ +void nvte_destroy_quantization_config(NVTEQuantizationConfig config); + #ifdef __cplusplus } // extern "C" @@ -610,6 +664,58 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \struct QuantizationConfigWrapper + * \brief C++ wrapper for NVTEQuantizationConfigWrapper. 
+ */ +class QuantizationConfigWrapper { + public: + QuantizationConfigWrapper() : config_{nvte_create_quantization_config()} {} + + QuantizationConfigWrapper(const QuantizationConfigWrapper &) = delete; + QuantizationConfigWrapper &operator=(const QuantizationConfigWrapper &) = delete; + + QuantizationConfigWrapper(QuantizationConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + QuantizationConfigWrapper &operator=(QuantizationConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_quantization_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~QuantizationConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_quantization_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEQuantizationConfig. + * + * \return NVTEQuantizationConfig held by this QuantizationConfigWrapper. + */ + operator NVTEQuantizationConfig() const noexcept { return config_; } + + /*! \brief Set whether to force power of 2 scales */ + void set_force_pow_2_scales(bool force_pow_2_scales) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, + &force_pow_2_scales, sizeof(bool)); + } + + /*! \brief Set small value to add to amax */ + void set_amax_epsilon(float amax_epsilon) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEpsilon, + &amax_epsilon, sizeof(float)); + } + + private: + /*! \brief Wrapped NVTEQuantizationConfig. 
*/ + NVTEQuantizationConfig config_ = nullptr; +}; + } // namespace transformer_engine #endif // __cplusplus diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 0bce83d98f..937383d5ec 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -39,6 +39,27 @@ class Format(Enum): HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) +@dataclass(frozen=True) +class MMParams: + """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator) + apply split accumulator or not, turning it on will increase accuracy but impact gemm performance, + so only turn it on for certain gemms + """ + + use_split_accumulator: bool = True + + +@dataclass(frozen=True) +class QParams: + """Quantization parameters. + power_2_scale: use power of 2 scale parameter + amax_epsilon: optional minimum value of abs max + """ + + power_2_scale: bool = False + amax_epsilon: float = 0.0 + + class Recipe: """ Base recipe class. @@ -52,6 +73,10 @@ def delayed(self): """Whether the given recipe is delayed scaling.""" return isinstance(self, DelayedScaling) + def float8_current_scaling(self): + """Whether the given recipe is (per-tensor) current scaling.""" + return isinstance(self, Float8CurrentScaling) + @dataclass() class DelayedScaling(Recipe): @@ -161,6 +186,75 @@ def __repr__(self) -> str: ) +@dataclass() +class Float8CurrentScaling(Recipe): + """ + Use the per-tensor current scaling factor strategy. + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID + Controls the FP8 data format used during forward and backward + pass. 
+ fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of gradient tensor dY + fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + used for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + used for calculating wgrad in backward pass + fp8_dpa: bool, default = `False` + Whether to enable FP8 dot product attention (DPA). When the model is placed in an + `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + inputs from higher precision to FP8, performs attention in FP8, and casts tensors + back to higher precision as outputs. FP8 DPA currently is only supported in the + `FusedAttention` backend. + fp8_mha: bool, default = `False` + Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting + operations mentioned above at the DPA boundaries. Currently only standard MHA modules + i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When + `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as + `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. + When `fp8_mha = True, fp8_dpa = True`, it becomes + `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + + Notes + ----- + * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are + subject to change in future Transformer Engine releases.
+ """ + + fp8_format: Format = Format.HYBRID + fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) + + @dataclass() class MXFP8BlockScaling(Recipe): """ diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu new file mode 100644 index 0000000000..3a25d71a3b --- /dev/null +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -0,0 +1,237 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/vectorized_pointwise.h" + +namespace transformer_engine { +namespace { + +constexpr int amax_kernel_threads = 512; + +template +__launch_bounds__(amax_kernel_threads) __global__ + void amax_kernel(const InputType *input, float *amax, const size_t N, + const size_t num_aligned_elements) { + VectorizedLoader loader(input, N); + InputType max = 0.f; + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const size_t M = num_aligned_elements; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const InputType val = static_cast(loader.separate()[i]); + __builtin_assume(max >= InputType{0.f}); + if constexpr (std::is_same_v) { +#if __CUDA_ARCH__ >= 800 + max = __hmax(__habs(val), max); +#else // Turing + max = static_cast<__nv_bfloat16>( + fmaxf(fabsf(static_cast(val)), static_cast(max))); +#endif + } else if constexpr (std::is_same_v) { + max = __hmax(__habs(val), max); + } else { + max = fmaxf(fabsf(val), max); + } + } + } + + // Reduce amax over block + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + atomicMaxFloat(amax, max); + } +} + +template +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { + // Zero out amax so we can update with atomic max (fail loudly if the memset cannot be enqueued) + NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); + + // Return immediately if tensor is empty + if (N == 0) { + return; + } + + // Figure out alignment + auto align = CheckAlignment(N, nvec, input); + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); + + // Figure out CUDA blocks + constexpr size_t threads = amax_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); + constexpr size_t 
max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + // Launch kernel + switch (align) { + case Alignment::SAME_ALIGNED: + amax_kernel + <<>>(input, amax, N, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + amax_kernel + <<>>(input, amax, N, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // This case is a logic error, since there is only one pointer (input) + // in the alignment check. Still safe to process without vectorization. + amax_kernel<1, true, InputType><<>>(input, amax, N, N); + break; + } + } + + // Check results + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace +} // namespace transformer_engine + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + using namespace transformer_engine; + + // Check input tensor + NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)"); + const auto &input = *reinterpret_cast(input_); + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor for amax computation must unquantized, " + "but got scaling_mode=", + to_string(input.scaling_mode)); + NVTE_CHECK(!is_fp8_dtype(input.data.dtype), + "Input tensor for amax computation must be unquantized, but got dtype=", + to_string(input.data.dtype)); + NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data"); + CheckInputTensor(input, "input_compute_amax"); + + // Check output tensor + NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); + auto &output = *reinterpret_cast(output_); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + "but got scaling_mode=", + to_string(output.scaling_mode)); + NVTE_CHECK(output.amax.numel() == 1, + "Output tensor for amax computation has invalid amax tensor " + "(expected 1 entry, got shape=", + output.amax.shape, ")"); + 
NVTE_CHECK(output.amax.dptr != nullptr, + "Output tensor for amax computation has amax tensor without data"); + NVTE_CHECK(output.amax.dtype == DType::kFloat32, + "Output tensor for amax computation has invalid amax tensor " + "(expected FP32, got dtype=", + to_string(output.amax.dtype), ")"); + CheckOutputTensor(output, "output_compute_amax"); + + // Compute amax + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); + launch_amax_kernel(reinterpret_cast(input.data.dptr), + reinterpret_cast(output.amax.dptr), input.data.numel(), + stream);); // NOLINT(*) +} + +namespace transformer_engine { +namespace { + +__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, + const float max_fp8, const bool force_pow_2_scales, + const float epsilon) { + float amax = *amax_ptr; + if (amax < epsilon) { + amax = epsilon; + } + + float scale = 1.f; + + if (isinf(amax) || amax == 0.f) { + *scale_ptr = scale; + return; + } + + scale = max_fp8 / amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + // use fp32 max to represent the scale + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + if (force_pow_2_scales) { + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. 
+ __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + + *scale_ptr = scale; +} + +} // namespace +} // namespace transformer_engine + +void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_scale_from_amax); + using namespace transformer_engine; + + // Check output tensor + NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); + auto &output = *reinterpret_cast(output_); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Tensor must be FP8 tensor with per-tensor scaling, " + "but got scaling_mode=", + to_string(output.scaling_mode)); + NVTE_CHECK(is_fp8_dtype(output.data.dtype), + "Tensor must be FP8, but got dtype=", to_string(output.data.dtype)); + NVTE_CHECK(output.amax.numel() == 1, + "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape, + ")"); + NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data"); + NVTE_CHECK(output.amax.dtype == DType::kFloat32, + "Tensor has invalid amax tensor (expected FP32, got dtype=", + to_string(output.amax.dtype), ")"); + NVTE_CHECK(output.scale.numel() == 1, + "Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape, + ")"); + NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data"); + NVTE_CHECK(output.scale.dtype == DType::kFloat32, + "Tensor has invalid scale tensor (expected FP32, got dtype=", + to_string(output.scale.dtype), ")"); + + // Check config + NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)"); + const auto &config = *reinterpret_cast(config_); + + // Maximum FP8 value + float max_fp8 = 0.f; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, + max_fp8 = Quantized_Limits::max_norm;); + + // Update scale (launch on the caller's stream, not the default stream) + compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(output.amax.dptr), + 
reinterpret_cast(output.scale.dptr), max_fp8, + config.force_pow_2_scales, config.amax_epsilon); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 54d5b0b5bf..23f272d5d5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -6,6 +6,7 @@ #include +#include #include #include "common.h" @@ -150,8 +151,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax - if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor"); + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) { NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, @@ -410,3 +410,79 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); } } + +NVTEQuantizationConfig nvte_create_quantization_config() { + return new transformer_engine::QuantizationConfig; +} + +void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, + "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == 
nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for quantization config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEQuantizationConfigForcePow2Scales: + std::memcpy(buf, &config_.force_pow_2_scales, attr_size); + break; + case kNVTEQuantizationConfigAmaxEpsilon: + std::memcpy(buf, &config_.amax_epsilon, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, const void *buf, + size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, + "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for quantization config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEQuantizationConfigForcePow2Scales: + std::memcpy(&config_.force_pow_2_scales, buf, attr_size); + break; + case kNVTEQuantizationConfigAmaxEpsilon: + std::memcpy(&config_.amax_epsilon, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void 
nvte_destroy_quantization_config(NVTEQuantizationConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 4cdb39b70a..7f3b9fb302 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -249,7 +249,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output.dtype(), OutputType, - if (is_delayed_tensor_scaling(output.scaling_mode)) { + if (is_tensor_scaling(output.scaling_mode)) { + // delayed scaling and current scaling are two variants of per-tensor scaling + constexpr const char *itype_name = TypeInfo::name; constexpr const char *otype_name = TypeInfo::name; constexpr size_t itype_size = sizeof(InputType); @@ -323,6 +325,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel <<>>( static_cast(input.data.dptr), diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index b4b86fe708..ba2890ada3 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1054,8 +1054,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryKernelLauncher( 
reinterpret_cast(input.data.dptr), @@ -1079,8 +1078,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp input->data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryGradKernelLauncher( reinterpret_cast(grad.data.dptr), @@ -1164,14 +1162,22 @@ template scaling_mode) || IS_DBIAS) { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + + if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? + NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + " on GPU with compute capability < 10.0."); } - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 63ce369892..227b3aaa48 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -844,7 +844,7 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war staging[warpid] = my_warp_max; } __syncthreads(); - compute_t result = 0; + compute_t result = 0.f; if (warpid == 0) { const float my_max = threadIdx.x < num_warps ? 
staging[threadIdx.x] : 0; result = warp_reduce_max(my_max); diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index ff475caf21..543b1181cb 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -24,6 +24,16 @@ torch.bfloat16: tex.DType.kBFloat16, } +TE_DType_To_Torch = { + tex.DType.kByte: torch.uint8, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, +} + AttnMaskTypes = ( "no_mask", "padding", diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 5775fe381d..23137a1003 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -46,15 +46,22 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); std::unique_ptr my_quantizer = convert_quantizer(quantizer); + // check for both quantizer & tensor type: + // mxfp8 tensor -> mxfp8 quantizer + // float8 tensor -> delayed scaling quantizer OR current scaling quantizer + // also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer for (auto [check_type, check_quantizer_type, create_tensor, _] : detail::custom_types_converters) { if (check_type(tensor.ptr())) { - NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()), - "Unexpected quantization params type."); + if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) { + continue; + } auto x = create_tensor(tensor, my_quantizer.get()); return x; } } + NVTE_CHECK(dynamic_cast(my_quantizer.get()) != nullptr, + "Unexpected quantization params type."); // Regular pyTorch tensor 
at::Tensor torch_tensor = tensor.cast(); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 40245cf2d9..980b2dff13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -50,6 +50,9 @@ namespace transformer_engine::pytorch { +// in python we have: dist_group_type = torch.distributed.ProcessGroup +using dist_group_type = c10d::ProcessGroup; + // Each tensor here is shape (N, ) holding all scaling // data for a single FP8 block, e.g. LayerNormLinear class FP8TensorMeta { @@ -136,6 +139,29 @@ class Float8Quantizer : public Quantizer { std::optional rowwise_data = std::nullopt) const override; }; +class Float8CurrentScalingQuantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + int amax_reduction_size; + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + + explicit Float8CurrentScalingQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + class MXFP8Quantizer : public Quantizer { public: DType dtype; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7ce33ee77b..1ef6f5258d 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include "common.h" #include "extensions.h" #include "pybind.h" @@ -24,7 +25,35 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto [te_output, out] = my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // for current scaling, we need to compute amax first and then quantize + // because cache cannot fit in the entire tensor to compute amax and quantize + // the quantizer should not need amax reduction, no process group needed here + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // activation function might change the input data range, we need to first call the activation function + // and then find the amax and scale of that and then do the quantization + // get a NoneQuantizer to calculate amax of activation output + auto my_quantizer_none = std::make_unique(py::none()); + auto [te_output_act, out_act] = + my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); + act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); + // use te_output_act as input to the compute amax and find the amax of activated tensor + nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + if (my_quantizer_cs->with_amax_reduction) { + NVTE_ERROR( + "per-tensor current scaling amax reduction is not supported in activation functions."); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set 
amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + } else { + act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + } return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 66dafdaafb..2c3ccff154 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -45,6 +45,29 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob } if (te_output.numel() == 0) return out; + + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + 
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0604847235..3e944c0fdd 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -24,6 +24,7 @@ namespace transformer_engine::pytorch { PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *Float8TensorBasePythonClass = nullptr; PyTypeObject *Float8QuantizerClass = nullptr; +PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; @@ -33,6 +34,8 @@ void init_float8_extension() { auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); Float8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); + Float8CurrentScalingQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer")); Float8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); auto fp8_base_module = diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index effeb8cb4d..427bf294d3 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -140,6 +140,123 @@ std::pair Float8Quantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) + : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const 
at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + // For current scaling, need several other components: + // 1. with_amax_reduction: bool + // 2. amax_reduction_group: torch.distributed.ProcessGroup or None + // 3. amax_reduction_size: int + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group"); + const c10::intrusive_ptr amax_reduction_group = + amax_reduction_group_obj.is_none() + ? nullptr + : amax_reduction_group_obj.cast>(); + const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + this->amax_reduction_size = amax_reduction_size; + + // fp8 current scaling specific quantization params + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); +} + +void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them) + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + // (no at::TensorOptions needed here: only pointers to existing buffers are passed on) + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + // quantize output and its transpose + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + 
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8CurrentScalingQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? 
py::cast(columnwise_data) : py::none(); + + //unlike delayed scaling, in current scaling, scale is not known, so scale_inv should be empty buffer + opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts); + + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); } diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 316e6515bf..b127b5d75b 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -12,7 
+12,7 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { return; } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d2607e4ed0..27d5869704 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -23,7 +23,8 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer if (transpose_valid) { transpose = tensor.attr("_transpose").cast>(); } - + // In the case of being called under tex.dequantize, the quantizer will be NoneQuantizer + // whose scaling mode is defaulted to NVTE_DELAYED_TENSOR_SCALING auto ret = TensorWrapper(quantizer->get_scaling_mode()); ret.set_rowwise_data(data.data_ptr(), dtype, shape); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 0679528b94..b0f55d7598 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -21,6 +21,7 @@ namespace transformer_engine::pytorch { extern PyTypeObject *Float8TensorPythonClass; extern PyTypeObject *Float8TensorBasePythonClass; extern PyTypeObject *Float8QuantizerClass; +extern PyTypeObject *Float8CurrentScalingQuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; extern PyTypeObject *MXFP8TensorBasePythonClass; extern PyTypeObject *MXFP8QuantizerClass; @@ -33,13 +34,17 @@ void init_mxfp8_extension(); namespace detail { -inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } +inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } + +inline bool 
IsFloat8CurrentScalingQuantizers(PyObject *obj) { + return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass; +} inline bool IsFloat8Tensor(PyObject *obj) { return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; } -inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } +inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; @@ -61,9 +66,11 @@ inline bool IsFloatingPointType(at::ScalarType type) { } constexpr std::array custom_types_converters = { - std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, + std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, + std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, CreateQuantizer)}; } // namespace detail diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fe023208d1..c1fc15968b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -21,7 +21,7 @@ from .utils import safely_set_viewless_tensor_data from .constants import dist_group_type from .fp8 import FP8GlobalStateManager -from .tensor.float8_tensor import Float8Quantizer, Float8Tensor +from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase @@ -859,7 +859,10 @@ def _all_gather_fp8( # Quantize input tensor if needed 
if not isinstance(input_, Float8TensorBase): - assert isinstance(quantizer, Float8Quantizer) + assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + # we cannot directly gather the transposed fp8 tensor + # so we need to disable columnwise usage for the quantizer + # and then set it back to the original value after quantizing init_columnwise_usage = quantizer.columnwise_usage quantizer.set_usage(columnwise=False) input_ = quantizer(input_) @@ -867,7 +870,7 @@ def _all_gather_fp8( # Construct output tensor out: Float8TensorBase - if isinstance(quantizer, Float8Quantizer): + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): dtype = torch.float32 device = "cuda" if isinstance(input_, Float8Tensor): @@ -885,6 +888,9 @@ def _all_gather_fp8( out._transpose_invalid = True else: raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + # For delayed scaling, scale_inv is from history, so we can pass it from input_ to out + # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv, + # so we can just pass it from input_ to out out._scale_inv = input_._scale_inv # Perform communication @@ -999,8 +1005,10 @@ def gather_along_first_dim( out_shape = list(input_.size()) out_shape[0] *= world_size - # FP8 case - if isinstance(input_, Float8TensorBase) or isinstance(quantizer, Float8Quantizer): + # FP8 case: delayed scaling or current scaling + if isinstance(input_, Float8TensorBase) or isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): return _all_gather_fp8( input_, process_group, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index f788368112..87298c2ec7 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -13,7 +13,13 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, 
MXFP8BlockScaling +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, +) from .constants import dist_group_type from .utils import get_device_compute_capability @@ -198,6 +204,8 @@ def add_fp8_tensors_to_global_buffer( fp8_meta: Dict[str, Any], ) -> None: """ + Delayed scaling only. + The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is to call this function in order to append it's FP8 tensor into a global @@ -211,7 +219,8 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return # Every module must call this function exactly once since @@ -326,7 +335,8 @@ def reduce_and_update_fp8_tensors( cls, forward: bool = True, ) -> None: - """Concatenate, reduce, and split amaxes in the global buffer.""" + """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" + # global_amax_buffer should only be non-empty for fp8 delayed scaling for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) @@ -426,6 +436,8 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. 
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + # delayed scaling only function, for other recipes (current scaling with any granularity), + # this is noop for other recipes because cls.global_amax_buffer is empty list cls.reduce_and_update_fp8_tensors(forward=True) @classmethod @@ -434,7 +446,8 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - to ensure both forward steps are numerically same. """ - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -459,8 +472,8 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non """Switch to the copied scaling factors and amaxes from phase 1 forward for indentical numerical outputs. """ - - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return # Store updated amaxes and scales from phase 1 post forward. 
@@ -478,8 +491,8 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - - if fp8_meta["recipe"].mxfp8(): + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) @@ -743,6 +756,8 @@ def create( cls = DelayedScalingRecipeState elif recipe.mxfp8(): cls = MXFP8BlockScalingRecipeState + elif recipe.float8_current_scaling(): + cls = Float8CurrentScalingRecipeState else: raise ValueError("{recipe.__class__.__name__} is not supported") return cls( @@ -813,6 +828,45 @@ def make_quantizers(self) -> list: ] +class Float8CurrentScalingRecipeState(RecipeState): + """Configuration for Per-tensor current scaling quantization. + + Per-tensor current quantization does not require state. + + """ + + recipe: Float8CurrentScaling + mode: str + dtype: tex.DType + device: torch.device + + def __init__( + self, + recipe: Float8CurrentScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + from .tensor.float8_tensor import Float8CurrentScalingQuantizer + + return [ + Float8CurrentScalingQuantizer(self.dtype, device=self.device) + for i in range(self.num_quantizers) + ] + + class MXFP8BlockScalingRecipeState(RecipeState): """Configuration for MXFP8 quantization. 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 84326f58ea..a44e209d36 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,6 +21,7 @@ from ..fp8 import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, + Float8CurrentScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -34,6 +35,7 @@ from ..tensor import QuantizedTensor, Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer __all__ = ["initialize_ub", "destroy_ub"] @@ -430,7 +432,10 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: - """Increase or decrease size of amax history based on given `length`. + """ + Delayed scaling only. + + Increase or decrease size of amax history based on given `length`. .. warning:: This changes the underlying amax memory location. @@ -489,6 +494,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): return + if recipe.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -851,6 +860,9 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + # FP8 current scaling does not support fused cast + dbias + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 10b21f25c6..8bf420ab0e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -88,6 +88,9 @@ def forward( # TODO Support MXFP8 # pylint: disable=fixme if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): raise NotImplementedError("GroupedLinear does not yet support MXFP8") + # TODO Support Float8 Current Scaling # pylint: disable=fixme + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2608fedeb1..7571b17c1f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, get_ub, @@ -55,8 +56,8 @@ restore_from_saved, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer 
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param - from ..cpp_extensions import ( general_gemm, ) @@ -159,6 +160,11 @@ def forward( # Configure quantizer for normalization output with_quantized_norm = fp8 and not return_layernorm_output + # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer + # so we need to set with_quantized_norm to False + if isinstance(input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False + if with_quantized_norm: if with_input_all_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -210,6 +216,10 @@ def forward( with_quantized_all_gather = False if fp8: input_quantizer.set_usage(rowwise=True, columnwise=False) + # ln_out in this has two possibilities: + # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel + # 2. in high precision, then we need to cast it and then gather in FP8 + # the output ln_out_total will be in FP8, and it's a full tensor ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, @@ -290,6 +300,12 @@ def forward( ln_out_total = ub_obj.get_buffer(input_quantizer) nvtx_range_push(f"{nvtx_label}.gemm") + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + out, *_, rs_out = general_gemm( weightmat, ln_out_total, @@ -297,7 +313,7 @@ def forward( quantization_params=output_quantizer, out_dtype=activation_dtype, bias=bias, - use_split_accumulator=_2X_ACC_FPROP, + use_split_accumulator=fprop_gemm_use_split_accumulator, ub=ub_obj, ub_type=ub_type, extra_output=rs_out, @@ -359,6 +375,7 @@ def forward( ctx.weight = weight ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None 
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -431,11 +448,12 @@ def backward( ctx.ub_bulk_wgrad, ] ) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + and (ctx.fp8_recipe is not None) ): - raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" - ) + if not ctx.fp8_recipe.delayed(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) saved_tensors = ctx.saved_tensors ( # pylint: disable=unbalanced-tuple-unpacking @@ -572,6 +590,12 @@ def backward( ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + dgrad, *_ = general_gemm( weight, grad_output, @@ -581,7 +605,7 @@ def backward( quantization_params=ctx.grad_input_quantizer, out=dgrad_bulk, out_dtype=ctx.activation_dtype, - use_split_accumulator=_2X_ACC_DGRAD, + use_split_accumulator=dgrad_gemm_use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, extra_output=rs_out, @@ -643,6 +667,14 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) + wgrad, grad_bias_, *_, rs_out = general_gemm( ln_out_total, grad_output, @@ -654,7 +686,7 @@ def backward( ), bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, + use_split_accumulator=wgrad_gemm_use_split_accumulator, 
accumulate=accumulate_wgrad_into_param_main_grad, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, @@ -1139,6 +1171,16 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif other recipes (mxfp8, etc) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1332,3 +1374,44 @@ def _get_quantizers(self, fp8_output): grad_output_quantizer, grad_input_quantizer, ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + 
tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f4ee0a1155..9bb76cb391 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, _ub_communicators, @@ -59,7 +60,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param - +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -73,17 +74,53 @@ __all__ = ["LayerNormMLP"] -def _act_func(activation: str): - funcs = { - "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), - "relu": (tex.relu, tex.drelu, tex.dbias_drelu), +def _get_act_func_supported_list(recipe: Optional[Recipe] = None): + if recipe is None: + # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + return { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": 
(tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + } + if recipe.delayed() or recipe.mxfp8(): + # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + return { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + } + # no activation fusion written yet + # Per-tensor current scaling: [] + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), "reglu": (tex.reglu, tex.dreglu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), - "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), - "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + "srelu": (tex.srelu, tex.dsrelu, None), } + + +def _act_func(activation: str, recipe: Optional[Recipe] = None): + # based on each quantization mode, we have different kernel fusion supported: + # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # Per-tensor current scaling: [] + funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") return 
funcs[activation] @@ -161,7 +198,9 @@ def forward( "Comm+GEMM overlap is only supported with FP8 delayed scaling" ) - activation_func = _act_func(activation)[0] + activation_func = _act_func( + activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + )[0] device = inp.device # Cast for native AMP @@ -175,6 +214,8 @@ def forward( # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned with_quantized_norm = fp8 and not return_layernorm_output + if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output @@ -220,6 +261,8 @@ def forward( zero_centered_gamma, ) + ln_out_return = ln_out if return_layernorm_output else None + # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False @@ -229,6 +272,10 @@ def forward( with_quantized_all_gather = False if fp8: fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + # ln_out in this has two possibilities: + # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel + # 2. 
in high precision, then we need to cast it and then gather in FP8 + # the output ln_out_total will be in FP8, and it's a full tensor ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, @@ -240,26 +287,19 @@ def forward( if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) else: + if fp8: + if not isinstance(ln_out, QuantizedTensor): + fc1_input_quantizer.set_usage( + rowwise=True, columnwise=backwards_needs_fc1_input + ) + ln_out = fc1_input_quantizer(ln_out) + elif backwards_needs_fc1_input: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + # here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer + # or fused into the layernorm kernel + # ln_out_total represents the full fp8 tensor, in this case, it's the same as ln_out ln_out_total = ln_out - # If residual connection is after LN, we need `ln_out` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. - ln_out_return = None - if return_layernorm_output: - ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out - if fp8 and not with_quantized_all_gather: - ln_out_total = fc1_input_quantizer(ln_out_total) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[ - slice_start:slice_end, ... 
- ] # TODO(pgadzinski) - check this # pylint: disable=fixme - else: - ln_out = ln_out_total - # Cast weights to expected dtype fc1_weight_final = fc1_weight fc2_weight_final = fc2_weight @@ -459,6 +499,7 @@ def forward( ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer @@ -546,11 +587,12 @@ def backward( ctx.ub_bulk_wgrad, ] ) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + and (ctx.fp8_recipe is not None) ): - raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" - ) + if not ctx.fp8_recipe.delayed(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) saved_tensors = ctx.saved_tensors ( # pylint: disable=unbalanced-tuple-unpacking @@ -733,22 +775,36 @@ def backward( fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.grad_fc1_output_quantizer is not None: dact = ctx.grad_fc1_output_quantizer(dact) - elif _act_func(ctx.activation)[2] is not None and ctx.fp8: + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and ctx.fp8 + ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func(ctx.activation)[2] + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func(ctx.activation)[1] + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] dact = activation_func_bwd( fc2_dgrad, 
fc1_out.to(ctx.activation_dtype), None ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + # TODO zhongboz: per-tensor current scaling has no bgrad fusion for now + if isinstance(ctx.grad_fc1_output_quantizer, Float8CurrentScalingQuantizer): + fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) + dact = ctx.grad_fc1_output_quantizer(dact) + else: + fc1_bias_grad, dact = tex.bgrad_quantize( + dact, ctx.grad_fc1_output_quantizer + ) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -1286,6 +1342,15 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif for other recipes (mxfp8, etc.) 
+ def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1494,3 +1559,76 @@ def _get_quantizers(self): grad_fc2_output_quantizer, grad_input_quantizer, ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # fc1_input_quantizer: set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # fc2_input_quantizer + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # fc1_weight_quantizer: also set numerical configs about weight + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # fc2_weight_quantizer + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + 
self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + if self.sequence_parallel and self.set_parallel_mode: + # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_size = self.tp_size diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f07cfb487b..675a8f929b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -11,6 +11,7 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, get_ub, @@ -228,6 +229,12 @@ def forward( inputmat_total = ub_obj.get_buffer(input_quantizer) nvtx_range_push(f"{nvtx_label}.gemm") + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP + if 
fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + out, *_, rs_out = general_gemm( weightmat, inputmat_total, @@ -235,7 +242,7 @@ def forward( quantization_params=output_quantizer, out_dtype=out_dtype, bias=bias, - use_split_accumulator=_2X_ACC_FPROP, + use_split_accumulator=fprop_gemm_use_split_accumulator, ub=ub_obj, ub_type=ub_type, extra_output=rs_out, @@ -277,6 +284,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 ctx.input_quantizer = input_quantizer ctx.grad_output_quantizer = grad_output_quantizer @@ -344,11 +352,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.ub_bulk_wgrad, ] ) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + and (ctx.fp8_recipe is not None) ): - raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" - ) + if not ctx.fp8_recipe.delayed(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) saved_tensors = ctx.saved_tensors inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking @@ -483,6 +492,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # dgrad GEMM nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_dgrad.use_split_accumulator + ) + dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -492,7 +509,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantization_params=ctx.grad_input_quantizer, out=dgrad_bulk, out_dtype=ctx.activation_dtype, - 
use_split_accumulator=_2X_ACC_DGRAD, + use_split_accumulator=dgrad_gemm_use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, extra_output=rs_out, @@ -551,6 +568,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) + wgrad, grad_bias_, _, rs_out = general_gemm( inputmat_total, grad_output, @@ -562,7 +587,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), bias=(bias if (grad_bias is None and not ctx.fp8) else None), out=main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, + use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, @@ -955,6 +980,16 @@ def __init__( else: self.gemm_bias_unfused_add = False + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif for other recipes (mxfp8, etc.) 
+ def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -1118,3 +1153,56 @@ def _get_quantizers(self, fp8_output, fp8_grad): grad_output_quantizer, grad_input_quantizer, ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # paralle related + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # set grad_output_quantizer with amax epsilon and power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "row": + # 
customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_size = self.tp_size diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5944039cf0..178401f6a6 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -14,6 +14,7 @@ from ..utils import devices_match, non_tn_fp8_gemm_supported from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..constants import dist_group_type aten = torch.ops.aten @@ -166,6 +167,167 @@ def create_tensor_from_data( ) +class Float8CurrentScalingQuantizer(Quantizer): + """Builder class for FP8 tensors with per-tensor current scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized by + multiplying with a scaling factor and casting to FP8. The max-abs + value ("amax") in the tensor is computed directly by scanning the input + high-precision tensor, without the need of any history window. + + Unlike delayed scaling, scale and amax tensors are not needed to initialize the + quantizer, becuse they are simply GPU buffers that will be filled by current + scaling quantization kernels, instead of using values taken from delayed scaling + history window. Therefore, device parameter is needed for tensor allocation. + + Both Float8CurrentScalingQuantizer and Float8Quantizer produces Float8Tensor, + because they are both per-tensor scaling, ie. one scaling factor per tensor. 
+ + """ + + """Scaling factor to multiply when quantizing to FP8""" + scale: torch.Tensor + """Max-abs value from last FP8 cast""" + amax: torch.Tensor + """FP8 datatype""" + dtype: TE_DType + """amax reduction options""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] + amax_reduction_size: Optional[int] + """Options about how to quantize the tensor""" + force_pow_2_scales: bool + amax_epsilon: float + + def __init__( + self, + fp8_dtype: TE_DType, + device: torch.device, + *, + rowwise: bool = True, + columnwise: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[dist_group_type] = None, + amax_reduction_size: Optional[int] = 1, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.scale = torch.empty(1, dtype=torch.float32, device=device) + self.amax = torch.empty(1, dtype=torch.float32, device=device) + self.dtype = fp8_dtype + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group + self.amax_reduction_size = amax_reduction_size + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + if not isinstance(dst, Float8Tensor): + raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor") + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update FP8 dtype + dst._fp8_dtype = self.dtype + + return dst + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> 
Float8Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + + # Allocate FP8 data transpose if needed + data_transpose = None + if self.columnwise_usage: + inner_dim = data.size(-1) + data_transpose = torch.empty( + inner_dim, + data.numel() // inner_dim, + dtype=torch.uint8, + device=device, + ) + + # Construct FP8 tensor + return Float8Tensor( + shape=shape, + dtype=dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=data_transpose, + quantizer=self, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # current scaling don't need to calibrate + return + + def create_tensor_from_data( + self, + data: torch.Tensor, + fake_dtype=torch.float32, + requires_grad: bool = False, + internal: bool = False, + ): + """ + Create Float8Tensor from raw uint8 data, unlike delayed scaling, + self.scale doesn't mean anything, so we are simply creating empty scale_inv + """ + assert data.dtype in [ + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + if internal: + return Float8TensorBase( + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) + return Float8Tensor( + shape=data.shape, + dtype=fake_dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) + + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data @@ -192,7 +354,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): FP8 format. 
data_transpose: torch.Tensor, optional FP8 transpose data in a uint8 tensor - quantizer: Float8Quantizer, optional + quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional Builder class for FP8 tensors """ @@ -229,10 +391,9 @@ def _get_quantizer(self) -> Quantizer: """ if self._quantizer is not None: return self._quantizer - return Float8Quantizer( - scale=torch.reciprocal(self._scale_inv), - amax=torch.empty(1, dtype=torch.float32, device=self.device), - fp8_dtype=self._fp8_dtype, + # Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling) + raise ValueError( + "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable" ) def quantize_( From 5bb771e3eaf15795697afa0a9f8a3653b4ce9b2a Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Sun, 9 Mar 2025 01:38:55 -0800 Subject: [PATCH 192/239] Verified TE2.0 with offloading (#1514) * Verified TE2.0 with offloading Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skipping tests for Ampere and removed child class preparing Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * offloading support for MXFP8 dtype Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed quantized tensor detection mechanism Signed-off-by: Selvaraj Anandaraj * Fix mxfp8 offload, lint errors, and var name Signed-off-by: Kirthi Shankar Sivamani * Supported disabling offloading for quantized tensors Signed-off-by: Selvaraj Anandaraj * bug fix Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed bugs Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * Added support for None in list of Quantized data tensors Signed-off-by: root * Hopper backward compatibility cleanup Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Coding style nit Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added guards Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_cpu_offloading.py | 60 +++++++++--- transformer_engine/pytorch/cpu_offload.py | 97 +++++++++++++++---- .../pytorch/tensor/float8_tensor.py | 4 - .../pytorch/tensor/mxfp8_tensor.py | 4 - .../pytorch/tensor/quantized_tensor.py | 2 +- 6 files changed, 124 insertions(+), 44 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 56d668bd12..e2fe2c0200 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,6 +24,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 exit $FAIL diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 61b4a2553c..ed7cdda85b 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ 
b/tests/pytorch/test_cpu_offloading.py @@ -7,8 +7,12 @@ from contextlib import nullcontext import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -SIZE = 4096 +# Check if FP8 supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +SIZE = 512 models = { "linear": te.Linear, @@ -18,40 +22,64 @@ def _get_input(): - return torch.empty((1, SIZE, SIZE)).cuda() # input size - 1 * 2048 * 2048 * 4b = 16MB + return torch.empty((128, SIZE, SIZE)).cuda() def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): - torch.cuda.empty_cache() - model = model_cls(SIZE, SIZE, 1) + + input_layer = model_cls(SIZE, SIZE) + hidden_layer = model_cls(SIZE, SIZE) + output_layer = model_cls(SIZE, SIZE) input = _get_input() if cpu_offload: - offload_context, sync_function = te.get_cpu_offload_context(enabled=True) + offload_context, sync_function = te.get_cpu_offload_context( + enabled=True, + num_layers=2, + model_layers=3, + offload_activations=True, + offload_weights=False, + ) else: offload_context = nullcontext() sync_function = lambda x: x with te.fp8_autocast(enabled=fp8), offload_context: - out = model(input) + out = input_layer(input) + out = sync_function(out) + with te.fp8_autocast(enabled=fp8), offload_context: + out = hidden_layer(out) out = sync_function(out) - input.data = torch.Tensor() # delete data from input - out.data = torch.Tensor() # delete data from out + with te.fp8_autocast(enabled=fp8), offload_context: + out = output_layer(out) + out = sync_function(out) + + max_mem_used = torch.cuda.memory_allocated() / 1024**2 + + out.sum().backward() + + del input_layer + del hidden_layer + del output_layer del input del out - torch.cuda.empty_cache() - allocated_memory_mb = torch.cuda.memory_allocated() / 1024**2 - del model - return allocated_memory_mb + torch.cuda.synchronize() + + return max_mem_used -@pytest.mark.parametrize("fp8", [False, True]) + 
+@pytest.mark.parametrize("fp8", [True, False]) @pytest.mark.parametrize("model_key", models.keys()) def test_cpu_offload(fp8, model_key) -> None: + + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + model_cls = models[model_key] + without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) - torch.cuda.empty_cache() + with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True) - assert without_offloading > 30 - assert with_offloading < 10 + assert with_offloading < without_offloading diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index c47130fe78..8294a76742 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -137,9 +137,7 @@ def __init__( super().__init__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push( - tensor.data, **self.handler_extra_kwargs - ) + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) return retrieve_identifier def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: @@ -235,19 +233,15 @@ def on_group_commit_backward(self): @staticmethod def offload(src_tensor, pin_memory=True): """Offload.""" - fp8_offload = isinstance(src_tensor, Float8Tensor) cpu_backup = torch.empty( src_tensor.size(), - dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + dtype=src_tensor.dtype, layout=src_tensor.layout, device="cpu", pin_memory=pin_memory, ) - if fp8_offload: - cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup) return state @@ -311,6 +305,9 @@ def __init__( self.num_layers = num_model_group # Data Structure to maintain reference to activation tensors self.tensor_tag_to_buf = {} + # Data structure to hold the FP8/MXFP8 tensor objects + 
self.fp8_tensor_object_map = {} + self.float8_transpose_cache_valid = {} # Tracking the number of layers offloaded self.offloaded_group_count = 0 # Core data structure that decides the window for offloading @@ -341,18 +338,46 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: ), ) + is_quantized_tensor = callable(getattr(tensor, "prepare_for_saving", None)) + if not torch_stray_tensor: + # obtain a unique tensor tag tensor_tag = (self.current_group, self.tensor_count_current_group) self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state - self.tensor_tag_to_state[tensor_tag] = tensor + if is_quantized_tensor: + tensor_list, _ = tensor.prepare_for_saving() + + self.tensor_tag_to_state[tensor_tag] = [] + self.tensor_tag_to_buf[tensor_tag] = [] - if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( - tensor - ): - self.tensor_tag_to_buf[tensor_tag] = tensor + self.fp8_tensor_object_map[tensor_tag] = tensor + if isinstance(tensor, Float8Tensor): + self.float8_transpose_cache_valid[tensor_tag] = getattr( + tensor, "_transpose_invalid" + ) + else: + tensor_list = [tensor] + + for t in tensor_list: + if is_quantized_tensor: + self.tensor_tag_to_state[tensor_tag].append(t) + else: + self.tensor_tag_to_state[tensor_tag] = t + + if ( + self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(t) + ): + if is_quantized_tensor: + self.tensor_tag_to_buf[tensor_tag].append(t) + # Need to clear the internal data reference for the quantized tensors + tensor.clear() + else: + self.tensor_tag_to_buf[tensor_tag] = t else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -364,7 +389,14 @@ def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) + + # Handling the quantized tensor case specially here + if isinstance(tensor, list): + 
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) + tensor = self.fp8_tensor_object_map.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + # the tensor should have been copied back in on_group_commit_backward() # which invokes bulk_reload_group. assert not isinstance(tensor, tuple) @@ -377,13 +409,23 @@ def bulk_offload_group(self, group_to_offload): group_id, _ = tensor_tag if group_id == group_to_offload: assert not isinstance(state, tuple) - tensor_on_device = state - # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker(tensor_on_device): - state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) - self.tensor_tag_to_state[tensor_tag] = state - tensor_on_device.data = torch.Tensor() # Force to release memory + is_quantized_tensor = isinstance(state, list) + + if is_quantized_tensor: + tensor_list = state + self.tensor_tag_to_state[tensor_tag] = [] + else: + tensor_list = [state] + + for tensor_on_device in tensor_list: + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + if is_quantized_tensor: + self.tensor_tag_to_state[tensor_tag].append(state) + else: + self.tensor_tag_to_state[tensor_tag] = state def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" @@ -433,6 +475,23 @@ def bulk_reload_group(self, group_to_reload): if isinstance(state, tuple): recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) self.tensor_tag_to_state[tensor_label] = recovered_tensor + elif isinstance(state, list): + tensor_list = [] + for state_tuple in state: + if isinstance(state_tuple, tuple): + tensor_list.append( + SynchronizedGroupOffloadHandler.reload(state_tuple) + ) + else: + tensor_list.append(state_tuple) + _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list) + if 
isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): + self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( + self.float8_transpose_cache_valid.pop(tensor_label) + ) + self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( + tensor_label + ) def on_group_commit_backward(self): # first decrement the current group. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 178401f6a6..9b1e3f3dc4 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -509,10 +509,6 @@ def clear(self): self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose_invalid = True - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: - """Prepare the tensor base for saving for backward""" - return [self], None - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index db369de803..6e3835fbef 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -285,10 +285,6 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: - """Prepare the tensor base for saving for backward""" - return [self], None - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index b540cd91a1..00815452da 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py 
@@ -27,7 +27,7 @@ def prepare_for_saving( if tensor is None: tensor_list.append(None) tensor_objects_list.append(None) - elif type(tensor) in (torch.Tensor, torch.nn.Parameter): + elif isinstance(tensor, torch.Tensor): tensor_list.append(tensor) tensor_objects_list.append(None) else: From b3e70353a424aaaff873eeeb6acb9b45ccb889c1 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 10 Mar 2025 09:19:48 -0700 Subject: [PATCH 193/239] Use internal quantizer for input to the modules (#1551) Internal quantizer for input to the modules Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7571b17c1f..97ab2a1107 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1358,7 +1358,7 @@ def _get_quantizers(self, fp8_output): grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = False + input_quantizer.internal = True weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if fp8_output: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9bb76cb391..bab4b9b44c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1528,7 +1528,7 @@ def _get_quantizers(self): ) = [None] * 8 if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - fc1_input_quantizer.internal = False # temporary + fc1_input_quantizer.internal = True 
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 675a8f929b..3aa8113dec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1136,7 +1136,7 @@ def _get_quantizers(self, fp8_output, fp8_grad): grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = False + input_quantizer.internal = True weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if fp8_output: From f0905517b36489147becf119a77a783bb546d6e5 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 10 Mar 2025 11:21:02 -0700 Subject: [PATCH 194/239] [PyTorch] Remove Megatron-LM convergence test (#1521) Remove Megatron-LM convergence test Signed-off-by: Tim Moon Co-authored-by: Kirthi Shankar Sivamani --- qa/L3_pytorch_convergence_test/test.sh | 14 -- tests/pytorch/distributed/print_logs.py | 132 ------------------ .../distributed/run_megatron_lm_gpt.sh | 120 ---------------- tests/pytorch/distributed/test_convergence.py | 112 --------------- 4 files changed, 378 deletions(-) delete mode 100644 qa/L3_pytorch_convergence_test/test.sh delete mode 100644 tests/pytorch/distributed/print_logs.py delete mode 100755 tests/pytorch/distributed/run_megatron_lm_gpt.sh delete mode 100644 tests/pytorch/distributed/test_convergence.py diff --git a/qa/L3_pytorch_convergence_test/test.sh b/qa/L3_pytorch_convergence_test/test.sh deleted file mode 100644 index 110e26cc8a..0000000000 --- a/qa/L3_pytorch_convergence_test/test.sh +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: ${TE_PATH:=/opt/transformerengine} - -pip install prettytable -git clone https://github.com/NVIDIA/Megatron-LM.git -cd Megatron-LM -git checkout b3375a0e38c10e2300ef4be031f7dcabab52b448 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py -python $TE_PATH/tests/pytorch/distributed/print_logs.py diff --git a/tests/pytorch/distributed/print_logs.py b/tests/pytorch/distributed/print_logs.py deleted file mode 100644 index 9d3cb3838f..0000000000 --- a/tests/pytorch/distributed/print_logs.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import os -import re -import glob -import datetime -from prettytable import PrettyTable -from matplotlib import pyplot as plt - -NUM_MOST_RECENT_RUNS = 100 - - -te_path = os.getenv("TE_PATH", "/opt/transformerengine") -mlm_log_dir = os.path.join(te_path, "ci_logs") -te_ci_log_dir = "/data/transformer_engine_ci_logs" -te_ci_plot_dir = os.path.join(te_ci_log_dir, "plots") - - -convergence_pattern = ( - "validation loss at iteration \d* on validation set | lm loss" - " value: ([\d.]*)E\+(\d*) | lm loss PPL: ([\d.]*)E\+(\d*)" -) - - -perf_pattern = "elapsed time per iteration \(ms\): ([\d.]*)" - - -def get_output_file(): - now = datetime.datetime.now() - default_fname = f"unknown_pipeline_id_{now.month}_{now.day}_{now.year}_{now.hour}_{now.minute}" - fname = f"{os.getenv('CI_PIPELINE_ID', default_fname)}.txt" - return os.path.join(te_ci_log_dir, fname) - - -def get_run_metrics(filename): - """Return the loss, perplexity, and step time for a given megatron-LM logfile.""" - - with open(filename, "r") as f: - data = f.read() - - # Loss and PPL - convergence_matches = re.findall(convergence_pattern, data) - loss = round(float(convergence_matches[1][0]) * (10 ** int(convergence_matches[1][1])), 2) - ppl = 
round(float(convergence_matches[2][2]) * (10 ** int(convergence_matches[2][3])), 2) - - step_times_str = re.findall(perf_pattern, data) - step_times = [float(x) for x in step_times_str] - avg_step_time = round(sum(step_times) / len(step_times), 2) - return loss, ppl, avg_step_time - - -def print_run_logs(): - tables = [] - raw_logs = [] - for model_config in os.listdir(mlm_log_dir): - model_config_dir = os.path.join(mlm_log_dir, model_config) - table = PrettyTable() - table.title = model_config - table.field_names = ["Config", "Loss", "Perplexity", "Avg time per step (ms)"] - for exp in os.listdir(model_config_dir): - filename = os.path.join(model_config_dir, exp) - loss, ppl, time_per_step = get_run_metrics(filename) - exp_name = exp[:-4] - table.add_row([exp_name, loss, ppl, time_per_step]) - raw_logs.append(f"{model_config} {exp_name} {loss} {ppl} {time_per_step}\n") - tables.append(table) - - with open(get_output_file(), "w") as f: - for raw_log in raw_logs: - f.write(raw_log) - for table in tables: - print(table) - - -def save_plot(title, legend, data, filename, ylabel): - x = list(range(1, len(data[0]) + 1)) - plt.figure() - for label, y in zip(legend, data): - plt.plot(x, y, "-o", label=label) - plt.title(title) - plt.legend() - plt.xlabel(f"Last {NUM_MOST_RECENT_RUNS} runs") - plt.ylabel(ylabel) - plt.savefig(os.path.join(te_ci_plot_dir, filename)) - - -def perf_and_loss_plots(): - files = glob.glob(os.path.join(te_ci_log_dir, "*.txt")) - files.sort(key=os.path.getctime) - files = files[-NUM_MOST_RECENT_RUNS:] - data = {} - for filename in files: - with open(filename) as file: - for line in file: - line = line.strip() - model_config, exp_name, loss, _, time_per_step = line.split(" ") - if model_config not in data: - data[model_config] = {} - if exp_name not in data[model_config]: - data[model_config][exp_name] = {"loss": [], "perf": []} - data[model_config][exp_name]["loss"].append(float(loss)) - 
data[model_config][exp_name]["perf"].append(float(time_per_step)) - - for model_config, experiments in data.items(): - lm_loss_data = [] - lm_perf_data = [] - legend = [] - for exp_name, lm_data in experiments.items(): - legend.append(exp_name) - lm_loss_data.append(lm_data["loss"]) - lm_perf_data.append(lm_data["perf"]) - save_plot( - model_config + " loss", - legend, - lm_loss_data, - model_config + "_loss.png", - "LM-Loss", - ) - save_plot( - model_config + " perf", - legend, - lm_perf_data, - model_config + "_perf.png", - "Time per step (ms)", - ) - - -if __name__ == "__main__": - print_run_logs() - perf_and_loss_plots() diff --git a/tests/pytorch/distributed/run_megatron_lm_gpt.sh b/tests/pytorch/distributed/run_megatron_lm_gpt.sh deleted file mode 100755 index 356399662c..0000000000 --- a/tests/pytorch/distributed/run_megatron_lm_gpt.sh +++ /dev/null @@ -1,120 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -# This script allows flexibly running various sizes of -# GPT3 models with named hyperparameters. - -# Trick to get kwargs. -for ARGUMENT in "$@" -do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" -done - -# Set defaults for all arguments. 
-: ${DP_SIZE:="1"} -: ${TP_SIZE:="1"} -: ${PP_SIZE:="1"} -: ${NUM_LAYERS:="12"} -: ${HIDDEN_SIZE:="768"} -: ${NHEADS:="12"} -: ${SEQLEN:="2048"} -: ${MAX_POSITION_EMBEDDINGS:="2048"} -: ${MBS:="8"} -: ${GBS:="32"} -: ${STEPS:="400"} -: ${LR:="6.0e-4"} -: ${MIN_LR:="6.0e-5"} -: ${SAVE_INTERVAL:="1000"} -: ${SPLIT:="98,2,0"} -: ${CLIP_GRAD:="1.0"} -: ${WEIGHT_DECAY:="0.1"} -: ${ADAM_BETA1:="0.9"} -: ${ADAM_BETA2:="0.95"} -: ${INIT_METHOD_STD:="0.023"} -: ${SP:="False"} -: ${DTYPE:="bf16"} -: ${WGRAD_FUSION:="True"} -: ${FP8:="False"} -: ${FP8_AMAX_HISTORY_LEN:="32"} -: ${TRANSFORMER_IMPL:="transformer_engine"} -: ${FILENAME:="log.txt"} - -# Logging. -DIR=`pwd` -TENSORBOARD_DIR="${DIR}/tensorboard" -CHECKPOINT_DIR="${DIR}/checkpoints" -mkdir -p ${TENSORBOARD_DIR} -mkdir -p ${CHECKPOINT_DIR} - -# Dataset. -. /data/gpt3/pile-cc1-cc2-shuf/gpt3_blend.sh - -# Set GP3 options. -options=" \ - --exit-duration-in-mins 230 \ - --tensor-model-parallel-size ${TP_SIZE} \ - --pipeline-model-parallel-size ${PP_SIZE} \ - --num-layers ${NUM_LAYERS} \ - --hidden-size ${HIDDEN_SIZE} \ - --num-attention-heads ${NHEADS} \ - --seq-length ${SEQLEN} \ - --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ - --micro-batch-size ${MBS} \ - --global-batch-size ${GBS} \ - --train-iters ${STEPS} \ - --lr ${LR} \ - --min-lr ${MIN_LR} \ - --lr-decay-style cosine \ - --log-interval 1 \ - --eval-iters 50 \ - --eval-interval 2000 \ - --data-path ${DATA_BLEND} \ - --vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json \ - --merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt \ - --save-interval ${SAVE_INTERVAL} \ - --save ${CHECKPOINT_DIR} \ - --split ${SPLIT} \ - --clip-grad ${CLIP_GRAD} \ - --weight-decay ${WEIGHT_DECAY} \ - --adam-beta1 ${ADAM_BETA1} \ - --adam-beta2 ${ADAM_BETA2} \ - --init-method-std ${INIT_METHOD_STD} \ - --log-params-norm \ - --log-num-zeros-in-grad \ - --transformer-impl ${TRANSFORMER_IMPL} \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - --fp8-margin 0 \ - 
--fp8-interval 1 \ - --fp8-amax-history-len ${FP8_AMAX_HISTORY_LEN} \ - --fp8-amax-compute-algo max" - -if [[ "$SP" == "True" ]]; then - options+=" --sequence-parallel" -fi - -if [[ "$WGRAD_FUSION" == "False" ]]; then - options+=" --no-gradient-accumulation-fusion" -fi - -if [[ "$FP8" != "False" ]]; then - options+=" --fp8-format ${FP8}" -fi - -if [[ "$DTYPE" != "fp32" ]]; then - options+=" --${DTYPE}" -fi - -# Run GPT3. -NUM_GPUS=$((${DP_SIZE}*${TP_SIZE}*${PP_SIZE})) -NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FLASH_ATTN=1 NVTE_FWD_LAYERNORM_SM_MARGIN=0 NVTE_BWD_LAYERNORM_SM_MARGIN=0 CUDA_DEVICE_MAX_CONNECTIONS=1 NVTE_BIAS_GELU_NVFUSION=0 NVTE_BIAS_DROPOUT_FUSION=0 python -m torch.distributed.launch --use_env --nnodes=1 --nproc_per_node=${NUM_GPUS} ${DIR}/pretrain_gpt.py ${options} 2>&1 | tee $FILENAME - -# Remove checkpoints. -rm -rf ${CHECKPOINT_DIR}/* diff --git a/tests/pytorch/distributed/test_convergence.py b/tests/pytorch/distributed/test_convergence.py deleted file mode 100644 index 2d468cd301..0000000000 --- a/tests/pytorch/distributed/test_convergence.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -import functools -import os -import pytest -import subprocess -from dataclasses import asdict, dataclass -from typing import List, Tuple, Union - -import torch - - -@dataclass() -class ModelConfigGPT: - NUM_LAYERS: int = 12 - HIDDEN_SIZE: int = 768 - NHEADS: int = 12 - SEQLEN: int = 2048 - MAX_POSITION_EMBEDDINGS: int = 2048 - LR: float = 6.0e-4 - MIN_LR: float = 6.0e-5 - SPLIT: str = "98,2,0" - CLIP_GRAD: float = 1.0 - WEIGHT_DECAY: float = 0.1 - ADAM_BETA1: float = 0.9 - ADAM_BETA2: float = 0.95 - INIT_METHOD_STD: float = 0.023 - - -model_configs = { - "126m": ModelConfigGPT(), -} - -dtypes = ["bf16"] - - -fp8_recipes = [False, "hybrid"] - - -all_boolean = [True, False] - - -te_path = os.getenv("TE_PATH", "/opt/transformerengine") -mlm_log_dir = os.path.join(te_path, "ci_logs") - - -@functools.lru_cache(maxsize=None) -def get_parallel_configs() -> List[Tuple[int, int]]: - """Returns valid combinations of (tp, pp).""" - sizes = [1, 2, 4] - num_devices = torch.cuda.device_count() - parallel_configs = [] - if num_devices > 1: - for dp in sizes: - for tp in sizes: - for pp in sizes: - if dp * tp * pp == num_devices: - parallel_configs.append((dp, tp, pp)) - return parallel_configs - - -def get_filename( - model: str, dp: int, tp: int, pp: int, sp: bool, use_te: bool, fp8_recipe: Union[bool, str] -) -> str: - sp = tp if sp else 1 - config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}" - config_dir = os.path.join(mlm_log_dir, config) - os.makedirs(config_dir, exist_ok=True) - fname = ( - f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt" - ) - return os.path.join(config_dir, fname) - - -def get_bash_arguments(filename: str, **kwargs) -> List[str]: - args = [] - script_path = os.path.join(te_path, "tests/pytorch/distributed/run_megatron_lm_gpt.sh") - args.append(script_path) - - for k, v in kwargs.items(): - args.append(f"{k}={str(v)}") - args.append(f"FILENAME={filename}") - return args - - -@pytest.mark.parametrize("sp", 
all_boolean) -@pytest.mark.parametrize("use_te", all_boolean) -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("dp, tp, pp", get_parallel_configs()) -@pytest.mark.parametrize("model", model_configs.keys()) -def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model): - if sp and tp == 1: - pytest.skip("No tensor parallel.") - if fp8_recipe and not use_te: - pytest.skip("TransformerEngine needed for FP8.") - subprocess.run( - get_bash_arguments( - get_filename(model, dp, tp, pp, sp, use_te, fp8_recipe), - DTYPE=dtype, - FP8=fp8_recipe, - SP=sp, - DP_SIZE=dp, - TP_SIZE=tp, - PP_SIZE=pp, - TRANSFORMER_IMPL="transformer_engine" if use_te else "local", - **asdict(model_configs[model]), - ), - check=True, - ) From 314ab9a845dce702bdb50942763fe770809ad51f Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 10 Mar 2025 13:36:16 -0700 Subject: [PATCH 195/239] Disable parallelism in core build test (#1550) Signed-off-by: Tim Moon --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3b6202263b..6653294c59 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,6 +28,7 @@ jobs: run: pip install . -v env: NVTE_FRAMEWORK: none + MAX_JOBS: 1 - name: 'Sanity check' run: python3 -c "import transformer_engine" working-directory: / From f3a009da7b577655fe3564982d3990b06aca53b1 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 10 Mar 2025 16:45:31 -0700 Subject: [PATCH 196/239] Revert "Use internal quantizer for input to the modules" (#1555) Revert "Use internal quantizer for input to the modules (#1551)" This reverts commit b3e70353a424aaaff873eeeb6acb9b45ccb889c1. 
Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 97ab2a1107..7571b17c1f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1358,7 +1358,7 @@ def _get_quantizers(self, fp8_output): grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = True + input_quantizer.internal = False weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if fp8_output: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bab4b9b44c..9bb76cb391 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1528,7 +1528,7 @@ def _get_quantizers(self): ) = [None] * 8 if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - fc1_input_quantizer.internal = True + fc1_input_quantizer.internal = False # temporary fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3aa8113dec..675a8f929b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1136,7 +1136,7 @@ def _get_quantizers(self, fp8_output, fp8_grad): grad_output_quantizer = None 
output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = True + input_quantizer.internal = False weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if fp8_output: From ab4fd3cf84bdbc49eaf6d70ac63f35c1a7e0e3d2 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 12 Mar 2025 20:53:28 +0800 Subject: [PATCH 197/239] =?UTF-8?q?Remove=20xla=5Fignore=5Fchannel=5Fid=20?= =?UTF-8?q?check=20and=20ignore=20Scan=20loop=20warning=20in=20un=E2=80=A6?= =?UTF-8?q?=20(#1540)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove xla_ignore_channel_id check and ignore Scan loop warning in unit test Signed-off-by: Reese Wang --- tests/jax/pytest.ini | 1 + transformer_engine/jax/cpp_extensions/attention.py | 13 ++----------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/jax/pytest.ini b/tests/jax/pytest.ini index 4b1f68aa77..1e835b2187 100644 --- a/tests/jax/pytest.ini +++ b/tests/jax/pytest.ini @@ -24,3 +24,4 @@ filterwarnings= ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning ignore:The host_callback APIs are deprecated .*:DeprecationWarning + ignore:Scan loop is disabled for fused ring attention.*:UserWarning diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 103f97827f..125fa96b3c 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -38,7 +38,6 @@ get_padded_spec, get_cudnn_version, is_ffi_enabled, - get_xla_flag, ) from ..sharding import ( global_mesh_resource, @@ -1607,14 +1606,7 @@ class _FusedAttnCPWithP2PHelper: def use_scanloop(): """Returns true if the implementation will use a scan loop for iteration.""" use_scan = 
bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1"))) - - # nvbug(4675071): Disable the HLO verifier for channel ID checks. - # A WAR was added to XLA: https://github.com/openxla/xla/pull/16779 - def truthy(val): - return val.lower() in ["1", "true"] - - x = use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy) - return x + return use_scan def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" @@ -1659,8 +1651,7 @@ def check_supported(self): if not self.use_scanloop(): warnings.warn( "Scan loop is disabled for fused ring attention. To enable set" - " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and" - " add --xla_experimental_ignore_channel_id=true to XLA_FLAGS." + " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment" ) def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: From 8487e506f52c770b12753fe7da75b14ab25dcd63 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 06:44:38 +0800 Subject: [PATCH 198/239] [PyTorch] Fix fused attention backward's FP8 dtypes (#1566) * fix dtypes in fused attn bwd for FP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add comments for dtypes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove redundant qkv_dtype in fwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove Nones in bwd returns Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 37 ++++++++++++++++--------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 537b43496f..dce38320a6 100644 --- a/transformer_engine/pytorch/attention.py +++ 
b/transformer_engine/pytorch/attention.py @@ -6095,7 +6095,6 @@ def forward( q, k, v, - qkv_dtype, attn_bias, attn_scale, dropout_p, @@ -6116,6 +6115,10 @@ def forward( # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False + + # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn fake_dtype = q.dtype QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( @@ -6154,6 +6157,7 @@ def forward( v_fp8 = QKV_quantizer(v) case _: raise "Invalid qkv_layout " + qkv_layout + # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -6183,6 +6187,8 @@ def forward( out_ret = out_fp8 else: out_ret = out_fp8.dequantize().view(out_fp8.shape) + # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 + # is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn out_save = out_ret if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): @@ -6211,7 +6217,7 @@ def forward( fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: - + # q, k, v, out_ret: torch.float16 or torch.bfloat16 out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -6280,8 +6286,6 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv - ctx.fake_dtype = fake_dtype - ctx.qkv_dtype = qkv_dtype ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill @@ -6305,6 +6309,11 @@ def backward(ctx, d_out): d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
+ # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2 + fake_dtype = d_out.dtype + d_out = d_out.contiguous() ( q_fp8, @@ -6364,6 +6373,9 @@ def backward(ctx, d_out): d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) + dqkv_dtype = TE_DType[d_out_fp8._data.dtype] + # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn + # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6374,8 +6386,8 @@ def backward(ctx, d_out): v_fp8, out_fp8, d_out_fp8, - ctx.fake_dtype, - ctx.qkv_dtype, + fake_dtype, + dqkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -6393,6 +6405,8 @@ def backward(ctx, d_out): ctx.deterministic, ) + # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 + # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 if not ctx.is_input_fp8: qkv_group = len(ctx.qkv_layout.split("_")) if qkv_group == 1: @@ -6423,6 +6437,8 @@ def backward(ctx, d_out): else: if isinstance(d_out, QuantizedTensor): d_out = d_out.dequantize() + dqkv_dtype = TE_DType[d_out.dtype] + # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6433,8 +6449,8 @@ def backward(ctx, d_out): v, out, d_out, - ctx.fake_dtype, - ctx.qkv_dtype, + fake_dtype, + dqkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -6482,7 +6498,6 @@ def backward(ctx, d_out): None, None, None, - None, ) # else, return (dqkv, dbias) return ( @@ -6496,7 +6511,6 @@ def backward(ctx, d_out): dq, dk, dv, - None, rest[0], None, None, @@ -6695,8 +6709,6 @@ def forward( cu_seqlens_q_padded = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_kv - qkv_dtype = TE_DType[query_layer.dtype] - use_FAv2_bwd = ( self.use_FAv2_bwd 
and (core_attention_bias_type == "no_bias") @@ -6768,7 +6780,6 @@ def forward( query_layer, key_layer, value_layer, - qkv_dtype, core_attention_bias, self.softmax_scale, self.attention_dropout if self.training else 0.0, From 31f32b3777817606ba5fc1e4daf2991288de1725 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 12 Mar 2025 17:46:17 -0700 Subject: [PATCH 199/239] Explicitly use `python3` and `pip3` executables (#1486) * Explicitly use python3 and pip3 Signed-off-by: Tim Moon * Run pre-commit as Python module Signed-off-by: Tim Moon * Replace some missed references to "python" or "pip" Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- README.rst | 4 +-- docs/installation.rst | 14 ++++---- qa/L0_cppunittest/test.sh | 2 +- qa/L0_jax_distributed_unittest/test.sh | 6 ++-- qa/L0_jax_lint/test.sh | 10 +++--- qa/L0_jax_unittest/test.sh | 16 ++++----- qa/L0_jax_wheel/test.sh | 16 ++++----- qa/L0_license/copyright_checker.py | 4 +-- qa/L0_license/test.sh | 2 +- qa/L0_pytorch_lint/test.sh | 10 +++--- qa/L0_pytorch_unittest/test.sh | 34 +++++++++---------- qa/L0_pytorch_wheel/test.sh | 16 ++++----- qa/L1_pytorch_distributed_unittest/test.sh | 14 ++++---- qa/L1_pytorch_mcore_integration/test.sh | 2 +- qa/L3_pytorch_FA_versions_test/test.sh | 10 +++--- qa/format.sh | 4 +-- tests/cpp/CMakeLists.txt | 2 +- .../distributed/test_comm_gemm_overlap.py | 2 +- .../fused_attn/test_fused_attn_with_cp.py | 2 +- transformer_engine/jax/__init__.py | 8 ++--- transformer_engine/pytorch/__init__.py | 8 ++--- transformer_engine/pytorch/attention.py | 6 ++-- 22 files changed, 96 insertions(+), 96 deletions(-) diff --git a/README.rst b/README.rst index ace8096c42..c4fde5bd11 100644 --- a/README.rst +++ b/README.rst @@ -173,7 +173,7 @@ To install the latest stable version of Transformer Engine, .. 
code-block:: bash - pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable @@ -184,7 +184,7 @@ Alternatively, the package can be directly installed from .. code-block:: bash - pip install transformer_engine[pytorch] + pip3 install transformer_engine[pytorch] To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). diff --git a/docs/installation.rst b/docs/installation.rst index dc194e12b0..10046d6306 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -34,7 +34,7 @@ Transformer Engine can be directly installed from `our PyPI =3.8.2" -pip install pytest==8.2.1 +pip3 install "nltk>=3.8.2" +pip3 install pytest==8.2.1 : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py # Test without custom calls -NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py +NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py -pip install -r $TE_PATH/examples/jax/mnist/requirements.txt -pip install -r $TE_PATH/examples/jax/encoder/requirements.txt +pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +python3 -m pytest -c 
$TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 71c1ad5b23..48254e24e1 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -6,16 +6,16 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install wheel +pip3 install wheel cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax +pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel wheel unpack dist/* sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" @@ -23,13 +23,13 @@ mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VE wheel pack ${WHL_BASE} rm dist/*.whl mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel cd transformer_engine/jax -NVTE_RELEASE_BUILD=1 python setup.py sdist +NVTE_RELEASE_BUILD=1 python3 setup.py sdist -pip install dist/* +pip3 install dist/* cd $TE_PATH -pip install dist/*.whl --no-deps +pip3 install dist/*.whl --no-deps -python $TE_PATH/tests/jax/test_sanity_import.py +python3 
$TE_PATH/tests/jax/test_sanity_import.py diff --git a/qa/L0_license/copyright_checker.py b/qa/L0_license/copyright_checker.py index bfd1973033..a0e137d1ef 100644 --- a/qa/L0_license/copyright_checker.py +++ b/qa/L0_license/copyright_checker.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # coding: utf-8 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. @@ -12,7 +12,7 @@ import datetime if len(sys.argv) < 2: - print("Usage: python copyright_checker.py ") + print("Usage: python3 copyright_checker.py ") path = sys.argv[1] diff --git a/qa/L0_license/test.sh b/qa/L0_license/test.sh index 4342e22c23..44b9469e55 100644 --- a/qa/L0_license/test.sh +++ b/qa/L0_license/test.sh @@ -6,4 +6,4 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -python $TE_PATH/qa/L0_license/copyright_checker.py $TE_PATH +python3 $TE_PATH/qa/L0_license/copyright_checker.py $TE_PATH diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index 13cf07cafc..81d7822d7f 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -6,19 +6,19 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install cpplint==1.6.0 pylint==3.3.1 +pip3 install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH echo "Checking common API headers" - cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include + python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include echo "Checking C++ files" - cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common - cpplint --recursive transformer_engine/pytorch + python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common + python3 -m cpplint --recursive transformer_engine/pytorch fi if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking 
Python files" - pylint --recursive=y transformer_engine/common transformer_engine/pytorch + python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch fi diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e2fe2c0200..ff7527841a 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -6,25 +6,25 @@ set -x : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 +pip3 install pytest==8.2.1 FAIL=0 -pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1 -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 -NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1 -NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1 +python3 -m pytest -v -s 
$TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1 +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1 +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 exit $FAIL diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index f650a30f01..5f583af31e 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -6,16 +6,16 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install wheel +pip3 install wheel cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch +pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. 
-NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel wheel unpack dist/* sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" @@ -23,13 +23,13 @@ mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VE wheel pack ${WHL_BASE} rm dist/*.whl mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel cd transformer_engine/pytorch -NVTE_RELEASE_BUILD=1 python setup.py sdist +NVTE_RELEASE_BUILD=1 python3 setup.py sdist -pip install dist/* +pip3 install dist/* cd $TE_PATH -pip install dist/*.whl --no-deps +pip3 install dist/*.whl --no-deps -python $TE_PATH/tests/pytorch/test_sanity_import.py +python3 $TE_PATH/tests/pytorch/test_sanity_import.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 5e3823d85c..597551abfe 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -4,15 +4,15 @@ : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 +pip3 install pytest==8.2.1 FAIL=0 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1 -# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1 +python3 -m pytest -v -s 
$TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1 +# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || FAIL=1 ### TODO Debug UB support with te.Sequential +python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1 exit $FAIL diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh index cfc6446909..2200d11455 100644 --- a/qa/L1_pytorch_mcore_integration/test.sh +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -40,7 +40,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 NVTE_BIAS_GELU_NVFUSION=0 NVTE_BIAS_DROPOUT_FUSION=0 -python +python3 -m torch.distributed.launch --use_env --nnodes=1 diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 8ed3002214..f57d055db5 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -6,13 +6,13 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 +pip3 install pytest==8.2.1 # Limit parallel build jobs to avoid overwhelming system resources export MAX_JOBS=4 # Iterate over Flash Attention versions -sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` +sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` if [ $sm_arch -gt 90 ] then FA_versions=(2.7.3) @@ -26,10 +26,10 @@ do # Build Flash Attention if [ "${fa_version}" \< "3.0.0" ] then - pip install flash-attn==${fa_version} + pip3 install flash-attn==${fa_version} else - pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" - 
python_path=`python -c "import site; print(site.getsitepackages()[0])"` + pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" + python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` mkdir -p $python_path/flashattn_hopper wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py fi diff --git a/qa/format.sh b/qa/format.sh index caaa0ba416..86fd8f1981 100644 --- a/qa/format.sh +++ b/qa/format.sh @@ -11,5 +11,5 @@ set -e cd $TE_PATH -pip install pre-commit -pre-commit run --all-files +pip3 install pre-commit +python3 -m pre_commit run --all-files diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 081cd14eb4..afc80cba43 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -26,7 +26,7 @@ enable_testing() include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) - execute_process(COMMAND bash -c "pip show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" + execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" OUTPUT_VARIABLE TE_LIB_PATH) endif() diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 52420efca5..eb6b5ca8ed 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -34,7 +34,7 @@ NUM_PROCS: int = torch.cuda.device_count() LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] if tex.ubuf_built_with_mpi(): - LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"] + LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python3"] # Fall back on CUDA IPC if the platform does not support CUDA multicast if not tex.device_supports_multicast(): diff --git 
a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 85950347ba..96321043bc 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -41,7 +41,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): args = [ - "python", + "python3", "-m", "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus_per_node), diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 4e38438a97..6dbe9c0e1d 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -35,15 +35,15 @@ def _load_library(): "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" - " transformer-engine[jax]==VERSION'" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using " + "'pip3 install transformer-engine[jax]==VERSION'" ) if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): _logger.info( - "Could not find package %s. Install transformer-engine using 'pip" - " install transformer-engine[jax]==VERSION'", + "Could not find package %s. Install transformer-engine using " + "'pip3 install transformer-engine[jax]==VERSION'", module_name, ) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 888836ec7f..166e72506b 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -45,15 +45,15 @@ def _load_library(): "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. 
Install transformer-engine using 'pip install" - " transformer-engine[pytorch]==VERSION'" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using " + "'pip3 install transformer-engine[pytorch]==VERSION'" ) if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): _logger.info( - "Could not find package %s. Install transformer-engine using 'pip" - " install transformer-engine[pytorch]==VERSION'", + "Could not find package %s. Install transformer-engine using " + "'pip3 install transformer-engine[pytorch]==VERSION'", module_name, ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index dce38320a6..5e5d4098b6 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -142,7 +142,7 @@ def _get_supported_versions(version_min, version_max): if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: fa_logger.debug( "flash-attn v2 is not installed. 
To use, please install it by" - """ "pip install flash-attn".""", + """ "pip3 install flash-attn".""", ) else: if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): @@ -197,8 +197,8 @@ def _get_supported_versions(version_min, version_max): # TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved # https://github.com/Dao-AILab/flash-attention/issues/1452 _flash_attn_3_installation_steps = """\ -(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" -(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` +(1) pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" +(2) python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` (3) mkdir -p $python_path/flashattn_hopper (4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" try: From 0e13788366b4a9cd21eb42df6fafcd11f6c4bd0a Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 13 Mar 2025 19:35:44 +0800 Subject: [PATCH 200/239] [JAX] FFI API compatibility with both 0.4 and 0.5 (#1562) Make ffi compatible with jax 0.4 Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 7 ++++++- .../jax/cpp_extensions/attention.py | 15 ++++++++++----- .../jax/cpp_extensions/custom_call.py | 12 +++++++++--- .../jax/cpp_extensions/normalization.py | 9 +++++++-- .../jax/cpp_extensions/quantization.py | 7 ++++++- transformer_engine/jax/cpp_extensions/softmax.py | 7 ++++++- .../jax/cpp_extensions/transpose.py | 9 +++++++-- 7 files changed, 51 insertions(+), 15 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 704740c56d..c9c40de7e3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ 
b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,13 +5,13 @@ from typing import Tuple, Sequence, Union, Callable import operator from functools import reduce, partial +from packaging import version import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax import ffi import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type @@ -28,6 +28,11 @@ from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = ["act_lu", "dact_lu", "act_lu_fp8"] diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 125fa96b3c..47425fe6d5 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2,12 +2,13 @@ # # See LICENSE for license information. 
"""JAX/TE custom ops for attention""" -from dataclasses import dataclass, replace -from functools import partial, reduce import operator import os -from typing import Optional, Tuple import warnings +from dataclasses import dataclass, replace +from functools import partial, reduce +from typing import Optional, Tuple +from packaging import version import jax import jax.numpy as jnp @@ -15,8 +16,6 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax import ffi - import transformer_engine_jax from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -50,6 +49,12 @@ ) +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + + __all__ = [ "FusedAttnHelper", "fused_attn_fwd", diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 422d81b267..66b5e1c923 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -4,13 +4,19 @@ """JAX/TE custom call""" from dataclasses import dataclass from enum import IntEnum +from packaging import version import jax from jax.interpreters import mlir -import transformer_engine_jax +import transformer_engine_jax from .misc import is_ffi_enabled +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + try: from jaxlib.hlo_helpers import custom_call except ImportError: @@ -29,11 +35,11 @@ class CustomCallAPIVersion(IntEnum): for _name, _value in transformer_engine_jax.registrations().items(): if _name.endswith("_ffi"): if is_ffi_enabled(): - jax.ffi.register_ffi_target( + ffi.register_ffi_target( _name, _value, platform="CUDA", 
api_version=CustomCallAPIVersion.FFI.value ) else: - jax.ffi.register_ffi_target( + ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 50248649ba..ed8f5dde7a 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -2,10 +2,11 @@ # # See LICENSE for license information. """JAX/TE custom ops for normalization""" -from functools import partial, reduce, cache import operator import os import warnings +from functools import partial, reduce, cache +from packaging import version import jax import jax.numpy as jnp @@ -13,7 +14,6 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax import ffi import transformer_engine_jax @@ -30,6 +30,11 @@ from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = [ "layernorm_fwd", diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index f3ecf5e230..d944612ef5 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -3,13 +3,13 @@ # See LICENSE for license information. 
"""JAX/TE custom ops for quantization""" from typing import Tuple +from packaging import version import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax import ffi import transformer_engine_jax from transformer_engine_jax import DType as TEDType @@ -25,6 +25,11 @@ ) from ..sharding import all_reduce_max_along_all_axes_except_PP +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = ["cast_fp8"] diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 42c6919d92..888e6a897a 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -6,13 +6,13 @@ from functools import partial, reduce import operator import warnings +from packaging import version import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax import ffi import transformer_engine_jax @@ -21,6 +21,11 @@ from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled from ..softmax import SoftmaxType +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = [ "scaled_softmax_fwd", diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 8353414235..ca42126e4b 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -2,16 +2,16 @@ # # See LICENSE for license information. 
"""JAX/TE custom ops for transpose""" +import operator from functools import partial, reduce from typing import Tuple, Sequence, Union, Callable -import operator +from packaging import version import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax import ffi import transformer_engine_jax from transformer_engine_jax import DType as TEDType @@ -33,6 +33,11 @@ from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = [ "transpose", From 8a20d666e4d63ff133621987b2bfe06e59711062 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 13 Mar 2025 07:39:21 -0700 Subject: [PATCH 201/239] Support tensors with only column-wise data (#1505) * Delete row-wise data in single-GPU linear forward Signed-off-by: Tim Moon * Debug Python->C++ parsing of transpose-only Float8Tensors Signed-off-by: Tim Moon * Debug tensor shape calculation without row-wise data Signed-off-by: Tim Moon * Debug correctness issues with only column-wise data Signed-off-by: Tim Moon * Only cache column-wise input in LayerNormLinear Signed-off-by: Tim Moon * Support MXFP8 all-gather with only column-wise data Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix moe cases, lint, rm unused ctx Signed-off-by: Kirthi Shankar Sivamani * Fix CPU activation offloading and use consistent logic for save/restore Signed-off-by: Kirthi Shankar Sivamani * Fix tests Signed-off-by: Kirthi Shankar Sivamani * Fix typo Signed-off-by: Kirthi Shankar Sivamani * RM stray file Signed-off-by: Kirthi Shankar Sivamani * Fix distributed and 
cpp tests Signed-off-by: Kirthi Shankar Sivamani * Fix norm cpp tests Signed-off-by: Kirthi Shankar Sivamani * Rm stray file Signed-off-by: Kirthi Shankar Sivamani * RM stray file Signed-off-by: Kirthi Shankar Sivamani * Fix MXFP8 AG Signed-off-by: Kirthi Shankar Sivamani * Fix FP8 with sequence parallelism Signed-off-by: Kirthi Shankar Sivamani * Fix UB bulk dgrad Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Tim Moon Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- .../cpp/operator/test_normalization_mxfp8.cu | 5 +- tests/cpp/test_common.cu | 9 +- transformer_engine/common/common.h | 88 +++++--- .../transformer_engine/transformer_engine.h | 2 +- .../common/transformer_engine.cpp | 88 +++++--- transformer_engine/pytorch/cpu_offload.py | 16 +- .../csrc/extensions/type_converters.cpp | 85 +++---- transformer_engine/pytorch/distributed.py | 213 +++++++++++------- .../pytorch/module/layernorm_linear.py | 21 +- transformer_engine/pytorch/module/linear.py | 21 +- .../pytorch/ops/basic/basic_linear.py | 4 +- .../tensor/_internal/float8_tensor_base.py | 6 +- .../tensor/_internal/mxfp8_tensor_base.py | 4 +- .../pytorch/tensor/float8_tensor.py | 47 ++-- .../pytorch/tensor/mxfp8_tensor.py | 68 ++++-- .../pytorch/tensor/quantized_tensor.py | 15 +- 16 files changed, 431 insertions(+), 261 deletions(-) diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index d1bdb6203b..191c62835b 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -92,7 +92,10 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training) input.to_cpu(); auto scaling_mode = input.scaling_mode(); assert(input.rowwise_shape().ndim == 2); - assert(input.columnwise_shape().ndim == 2); + + if (is_training) { + assert(input.columnwise_shape().ndim == 2); + } dequantize_1x_kernel(input.rowwise_cpu_dptr(), 
input.rowwise_cpu_scale_inv_ptr(), diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 24aff83d8a..8565e5d5c6 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -83,9 +83,11 @@ size_t product(const NVTEShape &shape, size_t begin, size_t end) { } return ret; } + size_t product(const NVTEShape &shape) { return product(shape, 0, shape.ndim); } + size_t product(const std::vector shape, size_t begin, size_t end) { size_t ret = 1; NVTE_CHECK(end <= shape.size()); @@ -193,6 +195,7 @@ Tensor::Tensor(const std::string& name, std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; NVTEShape normalized_shape = convertShape(normalized_shape_v); + NVTEShape columnwise_shape{nullptr, 0}; std::vector columnwise_shape_vec; if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { @@ -207,7 +210,11 @@ Tensor::Tensor(const std::string& name, columnwise_shape_vec.emplace_back(shape.data[i]); } } - const NVTEShape columnwise_shape{columnwise_shape_vec.data(), columnwise_shape_vec.size()}; + + if (columnwise) { + columnwise_shape.data = columnwise_shape_vec.data(); + columnwise_shape.ndim = columnwise_shape_vec.size(); + } tensor_ = TensorWrapper(scaling_mode); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 4163505db6..ac58398551 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -29,6 +29,9 @@ namespace transformer_engine { +std::string to_string(const DType type); +std::string to_string(const NVTEScalingMode &mode); + inline bool is_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } @@ -108,17 +111,8 @@ struct Tensor { scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} int numel() const { - NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, - "Tensor does not hold any data!"); size_t acc = 1; - if (data.dptr != nullptr) { - for (const auto &dim : data.shape) { - acc *= dim; - 
} - return acc; - } - // data is empty, use columnwise_data - for (const auto &dim : columnwise_data.shape) { + for (const auto dim : shape()) { acc *= dim; } return acc; @@ -126,7 +120,10 @@ struct Tensor { bool has_data() const noexcept { return data.dptr != nullptr; } - bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; } + // Check for size (not just pointer) for 0-dim or no token cases. + bool has_columnwise_data() const noexcept { + return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0; + } DType dtype() const { if (has_data()) return data.dtype; @@ -135,24 +132,54 @@ struct Tensor { return data.dtype; } + std::vector shape() const { + /* Note: We sometimes experience spurious compiler errors + * (-Wstringop-overflow) from this function. It appears that GCC + * has some bugs with std::vector (see + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). + */ + switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: + if (!has_data() && has_columnwise_data()) { + std::vector ret; + if (!columnwise_data.shape.empty()) { + for (size_t i = 1; i < columnwise_data.shape.size(); i++) { + ret.push_back(columnwise_data.shape[i]); + } + ret.push_back(columnwise_data.shape.front()); + } + return ret; + } else { + return data.shape; + } + break; + case NVTE_MXFP8_1D_SCALING: + if (!has_data() && has_columnwise_data()) { + return columnwise_data.shape; + } else { + return data.shape; + } + break; + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); + return {}; + } + } + /*! Matrix height after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * as a (D1*D2*...*D(n-1), Dn) matrix. 
*/ size_t flat_first_dim() const { - if (!has_data() && has_columnwise_data()) { - const auto &data_shape = columnwise_data.shape; - if (data_shape.empty()) return 1; - if (is_tensor_scaling(scaling_mode)) { - return product(data_shape, 1, data_shape.size()); - } else { - return product(data_shape, 0, data_shape.size() - 1); + const auto &full_shape = shape(); + size_t ret = 1; + if (!full_shape.empty()) { + for (size_t i = 0; i < full_shape.size() - 1; i++) { + ret *= full_shape[i]; } } - const auto &data_shape = data.shape; - if (data_shape.empty()) return 1; - return product(data_shape, 0, data_shape.size() - 1); + return ret; } /*! Matrix width after tensor is flattened to 2D @@ -161,18 +188,12 @@ struct Tensor { * as a (D1*D2*...*D(n-1), Dn) matrix. */ size_t flat_last_dim() const { - if (!has_data() && has_columnwise_data()) { - const auto &data_shape = columnwise_data.shape; - if (data_shape.empty()) return 1; - if (is_tensor_scaling(scaling_mode)) { - return data_shape.front(); - } else { - return data_shape.back(); - } + const auto &full_shape = shape(); + if (full_shape.empty()) { + return 1; + } else { + return full_shape.back(); } - const auto &data_shape = data.shape; - if (data_shape.empty()) return 1; - return data_shape.back(); } }; @@ -477,9 +498,6 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt bool is_fp8_dtype(const DType t); -std::string to_string(const DType type); -std::string to_string(const NVTEScalingMode &type); - /*! 
\brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e91f3c4836..dd1cfb8ddb 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -555,7 +555,7 @@ class TensorWrapper { * \return Number of elements in the tensor. */ size_t numel() const noexcept { - if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + if (tensor_ == nullptr) return 0; return nvte_tensor_numel(tensor_); } diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 23f272d5d5..1f8bfca2c9 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -212,37 +212,58 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { - if (tensor == nullptr) return {nullptr, 0}; - const auto &t = *reinterpret_cast(tensor); - NVTEShape ret; - - // FP8 tensor keeps shape in rowwise data - if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - return ret; + if (tensor == nullptr) { + NVTE_ERROR("Invalid tensor"); } + NVTEShape ret; - // Get shape based on what data is available - if (t.has_data()) { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - return ret; - } - if (t.has_columnwise_data()) { - ret.data = t.columnwise_data.shape.data(); - ret.ndim = t.columnwise_data.shape.size(); - return ret; + // Determine tensor shape depending on tensor format + const auto &t = *reinterpret_cast(tensor); + switch (t.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!t.has_data() && t.has_columnwise_data()) { + 
// We can infer tensor shape if FP8 tensor only has FP8 data + // transpose. However, NVTEShape only contains a pointer and + // cannot store temporary data. We hack around this by caching + // the tensor shape within the empty FP8 data. + auto &shape_cache = const_cast &>(t.data.shape); + shape_cache.clear(); + if (!t.columnwise_data.shape.empty()) { + for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) { + shape_cache.push_back(t.columnwise_data.shape[i]); + } + shape_cache.push_back(t.columnwise_data.shape.front()); + } + ret.data = shape_cache.data(); + ret.ndim = shape_cache.size(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (!t.has_data() && t.has_columnwise_data()) { + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + } + break; + } + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", + transformer_engine::to_string(t.scaling_mode), "\""); } - // Tensor has no data - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); return ret; } NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { - if (tensor == nullptr) return {nullptr, 0}; + if (tensor == nullptr) { + NVTE_ERROR("Invalid tensor"); + } const auto &t = *reinterpret_cast(tensor); NVTEShape ret; ret.data = t.columnwise_data.shape.data(); @@ -250,25 +271,20 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { return ret; } -size_t nvte_tensor_ndim(const NVTETensor tensor) { - if (tensor == nullptr) return 0; - const auto &t = *reinterpret_cast(tensor); - return t.data.shape.size(); -} +size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { - if (tensor == nullptr) return 0; - const auto &t = *reinterpret_cast(tensor); - 
NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); - return t.data.shape[dim]; + const auto &shape = nvte_tensor_shape(tensor); + NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim, + " in a shape array with ", shape.ndim, " entries"); + return shape.data[dim]; } size_t nvte_tensor_numel(const NVTETensor tensor) { - if (tensor == nullptr) return 0; - const auto &t = *reinterpret_cast(tensor); + const auto &shape = nvte_tensor_shape(tensor); size_t numel = 1; - for (auto size : t.data.shape) { - numel *= size; + for (size_t i = 0; i < shape.ndim; i++) { + numel *= shape.data[i]; } return numel; } diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 8294a76742..93df512ac6 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -419,11 +419,25 @@ def bulk_offload_group(self, group_to_offload): tensor_list = [state] for tensor_on_device in tensor_list: + # `tensor_offloaded` is a hacky way of dealing with columnwise-only + # quantized tensors for CPU offloading. The complication is due to + # the `rowwise_data` being `None`. The offloading checker incorrectly + # returns `False` and the entire `state` ([None, columnwise_tensor]) + # is added to the tensor tag state dict. A better design would change + # how quantized tensors are kept track of in the offload handler. + # Currently at every stage it is ensured that a quantized tensor is a + # list whereas a non-quantized tensor is standalone object, which is + # not good! 
TODO(@sanandaraj5597) + tensor_offloaded = False # if offload, return the reference to cpu copy if self.tensor_need_offloading_checker(tensor_on_device): + tensor_offloaded = True state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) if is_quantized_tensor: - self.tensor_tag_to_state[tensor_tag].append(state) + if tensor_offloaded: + self.tensor_tag_to_state[tensor_tag].append(state) + else: + self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) else: self.tensor_tag_to_state[tensor_tag] = state diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index 27d5869704..d5654fb43a 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -4,6 +4,10 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include +#include + #include "common.h" #include "pybind.h" @@ -11,67 +15,72 @@ namespace transformer_engine::pytorch { namespace detail { TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { - const at::Tensor &data = tensor.attr("_data").cast(); - const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast(); - float *scale_inv_dptr = reinterpret_cast(scale_inv.data_ptr()); - const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(quantizer->get_scaling_mode()); + + bool data_exists = !tensor.attr("_data").is_none(); + bool transpose_exists = + !tensor.attr("_transpose_invalid").cast() && !tensor.attr("_transpose").is_none(); - const auto &shape = getTensorShape(data); + NVTE_CHECK(data_exists || transpose_exists, "No data found for FP8 Tensor."); - bool transpose_valid = !tensor.attr("_transpose_invalid").cast(); - std::optional transpose = std::nullopt; - if (transpose_valid) { - transpose = tensor.attr("_transpose").cast>(); + // FP8 data + 
const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); + if (data_exists) { + const auto &data = tensor.attr("_data").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); } - // In the case of being called under tex.dequantize, the quantizer will be NoneQuantizer - // whose scaling mode is defaulted to NVTE_DELAYED_TENSOR_SCALING - auto ret = TensorWrapper(quantizer->get_scaling_mode()); - ret.set_rowwise_data(data.data_ptr(), dtype, shape); - if (transpose_valid && transpose != std::nullopt) { - const auto &transpose_shape = getTensorShape(*transpose); - ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape); + // FP8 data transpose + if (transpose_exists) { + const auto &data_transpose = tensor.attr("_transpose").cast(); + ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose)); } - const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type()); - const auto scale_inv_shape = getTensorShape(scale_inv); - ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + // Scale-inverse + { + const auto &scale_inv = tensor.attr("_scale_inv").cast(); + float *dptr = reinterpret_cast(scale_inv.data_ptr()); + const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type()); + const auto &shape = getTensorShape(scale_inv); + ret.set_rowwise_scale_inv(dptr, dtype, shape); + ret.set_columnwise_scale_inv(dptr, dtype, shape); + } + + // Quantizer state quantizer->set_quantization_params(&ret); + return ret; } TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { - const DType dtype = tensor.attr("_fp8_dtype").cast(); auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + NVTE_CHECK(rowwise_usage || columnwise_usage, 
"No data found for MXFP8 Tensor."); + + // Row-scaled data + const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); if (rowwise_usage) { - const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); - const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); - void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto &shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); - - const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); - ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape); + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv)); } + // Column-scaled data if (columnwise_usage) { - const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); - const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); - void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto &shape = getTensorShape(data_colwise); - ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); - - const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); - ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0, - scale_inv_colwise_shape); + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, + getTensorShape(scale_inv)); } + // Quantizer state quantizer->set_quantization_params(&ret); + return ret; } diff --git a/transformer_engine/pytorch/distributed.py 
b/transformer_engine/pytorch/distributed.py index c1fc15968b..2a614f67d7 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -22,7 +22,7 @@ from .constants import dist_group_type from .fp8 import FP8GlobalStateManager from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer -from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase @@ -819,30 +819,30 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False + inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) # Bypass the function if we are using only 1 GPU. 
if world_size == 1: - return input_, None + return inp, None - dim_size = list(input_.size()) + dim_size = list(inp.size()) assert ( dim_size[0] % world_size == 0 ), "First dimension of the tensor should be divisible by tensor parallel size" dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( - output, input_.contiguous(), group=tp_group, async_op=async_op + output, inp.contiguous(), group=tp_group, async_op=async_op ) return output, handle def _all_gather_fp8( - input_: torch.Tensor, + inp: torch.Tensor, process_group: dist_group_type, *, async_op: bool = False, @@ -854,18 +854,18 @@ def _all_gather_fp8( # Output tensor dims if out_shape is None: - out_shape = list(input_.size()) + out_shape = list(inp.size()) out_shape[0] *= world_size # Quantize input tensor if needed - if not isinstance(input_, Float8TensorBase): + if not isinstance(inp, Float8TensorBase): assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer # and then set it back to the original value after quantizing init_columnwise_usage = quantizer.columnwise_usage quantizer.set_usage(columnwise=False) - input_ = quantizer(input_) + inp = quantizer(inp) quantizer.set_usage(columnwise=init_columnwise_usage) # Construct output tensor @@ -873,30 +873,30 @@ def _all_gather_fp8( if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): dtype = torch.float32 device = "cuda" - if isinstance(input_, Float8Tensor): - dtype = input_.dtype - device = input_.device + if isinstance(inp, Float8Tensor): + dtype = inp.dtype + device = inp.device out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - elif isinstance(input_, Float8Tensor): - out = 
input_.make_like(input_, shape=out_shape) + elif isinstance(inp, Float8Tensor): + out = inp.make_like(inp, shape=out_shape) out._data = torch.empty_like( out_shape, dtype=torch.uint8, - device=input_.device, + device=inp.device, ) out._transpose = None out._transpose_invalid = True else: raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") - # For delayed scaling, scale_inv is from history, so we can pass it from input_ to out + # For delayed scaling, scale_inv is from history, so we can pass it from inp to out # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv, - # so we can just pass it from input_ to out - out._scale_inv = input_._scale_inv + # so we can just pass it from inp to out + out._scale_inv = inp._scale_inv # Perform communication handle = torch.distributed.all_gather_into_tensor( out._data, - input_._data.contiguous(), + inp._data.contiguous(), group=process_group, async_op=async_op, ) @@ -914,7 +914,7 @@ def _all_gather_fp8( def _all_gather_mxfp8( - input_: torch.Tensor, + inp: torch.Tensor, process_group: dist_group_type, *, async_op: bool = False, @@ -925,27 +925,56 @@ def _all_gather_mxfp8( # Tensor dims world_size = get_distributed_world_size(process_group) - in_shape = list(input_.size()) + in_shape = list(inp.size()) if out_shape is None: out_shape = [in_shape[0] * world_size] + in_shape[1:] - # Gather MXFP8 data for row-wise usage - if quantizer.rowwise_usage and not quantizer.columnwise_usage: + # For cases where inp has dimensions that cannot be quantized, + # we gather in high precision followed by a cast to FP8. 
+ if ( + not isinstance(inp, MXFP8TensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + inp_dtype = inp.dtype + inp_device = inp.device + + # Cast input tensor to MXFP8 with required data + if not isinstance(inp, MXFP8TensorBase): + inp = quantizer(inp) + elif ( + inp.rowwise_data is None + and quantizer.rowwise_usage + or inp.columnwise_data is None + and quantizer.columnwise_usage + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to MXFP8." + ) + inp = quantizer(inp.dequantize()) - # Cast input tensor to MXFP8 if needed - if not isinstance(input_, MXFP8TensorBase): - input_ = quantizer(input_) + # Construct MXFP8 output tensor + out = quantizer.make_empty(out_shape, dtype=inp_dtype, device=inp_device) - # Construct MXFP8 output tensor - dtype = torch.float32 - device = "cuda" - if isinstance(input_, MXFP8Tensor): - dtype = input_.dtype - device = input_.device - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + # Async op handle + handle = None + + # Gather MXFP8 data for row-wise usage + if quantizer.rowwise_usage: # Remove padding from MXFP8 scale-inverses - in_scale_inv = input_._rowwise_scale_inv + in_scale_inv = inp._rowwise_scale_inv out_scale_inv = out._rowwise_scale_inv flattened_in_shape0 = math.prod(in_shape[:-1]) if in_scale_inv.size(0) != flattened_in_shape0: @@ -954,40 +983,52 @@ def _all_gather_mxfp8( out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers - with torch.distributed._coalescing_manager( + if handle is not None: + handle.wait() + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, group=process_group, - device=device, - async_ops=async_op, - 
) as coalescing_manager: - torch.distributed.all_gather_into_tensor( - out._rowwise_data, - input_._rowwise_data, - group=process_group, - ) - torch.distributed.all_gather_into_tensor( - out_scale_inv, - in_scale_inv, - group=process_group, - ) - handle = coalescing_manager if async_op else None - return out, handle + ) + handle = torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + async_op=async_op, + ) - # Gather in high precision and quantize for column-wise usage - if isinstance(input_, QuantizedTensor): - input_ = input_.dequantize(dtype=torch.bfloat16) - out = torch.empty( - out_shape, - dtype=input_.dtype, - device=input_.device, - memory_format=torch.contiguous_format, - ) - torch.distributed.all_gather_into_tensor(out, input_, group=process_group) - out = quantizer(out) - return out, None + # Gather MXFP8 data for column-wise usage + if quantizer.columnwise_usage: + + # Remove padding from MXFP8 scale-inverses + in_scale_inv = inp._columnwise_scale_inv + out_scale_inv = out._columnwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv[flattened_in_shape0 * world_size :].zero_() + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + if handle is not None: + handle.wait() + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + handle = torch.distributed.all_gather_into_tensor( + out._columnwise_data, + inp._columnwise_data, + group=process_group, + async_op=async_op, + ) + + return out, handle def gather_along_first_dim( - input_: torch.Tensor, + inp: torch.Tensor, process_group: dist_group_type, async_op: bool = False, quantizer: Optional[Quantizer] = None, @@ -997,20 +1038,20 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = 
get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(input_, QuantizedTensor): - input_ = quantizer(input_) - return input_, None + if quantizer is not None and not isinstance(inp, QuantizedTensor): + inp = quantizer(inp) + return inp, None # Output tensor dims - out_shape = list(input_.size()) + out_shape = list(inp.size()) out_shape[0] *= world_size # FP8 case: delayed scaling or current scaling - if isinstance(input_, Float8TensorBase) or isinstance( + if isinstance(inp, Float8TensorBase) or isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): return _all_gather_fp8( - input_, + inp, process_group, async_op=async_op, quantizer=quantizer, @@ -1018,10 +1059,10 @@ def gather_along_first_dim( ) # MXFP8 case - if isinstance(input_, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) return _all_gather_mxfp8( - input_, + inp, process_group, async_op=async_op, quantizer=quantizer, @@ -1034,36 +1075,36 @@ def gather_along_first_dim( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(input_, QuantizedTensor): - input_ = input_.dequantize() + if isinstance(inp, QuantizedTensor): + inp = inp.dequantize() out = torch.empty( out_shape, - dtype=input_.dtype, - device=input_.device, + dtype=inp.dtype, + device=inp.device, memory_format=torch.contiguous_format, ) - torch.distributed.all_gather_into_tensor(out, input_, group=process_group) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) out = quantizer(out) return out, None # Dequantize quantized tensor if not supported - if isinstance(input_, QuantizedTensor): + if isinstance(inp, QuantizedTensor): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." 
) - input_ = input_.dequantize() + inp = inp.dequantize() # Communication for plain PyTorch tensors out = torch.empty( out_shape, - dtype=input_.dtype, - device=input_.device, + dtype=inp.dtype, + device=inp.device, memory_format=torch.contiguous_format, ) handle = torch.distributed.all_gather_into_tensor( out, - input_.contiguous(), + inp.contiguous(), group=process_group, async_op=async_op, ) @@ -1071,7 +1112,7 @@ def gather_along_first_dim( def allreduce( - input_: torch.Tensor, + inp: torch.Tensor, tp_group: Optional[dist_group_type] = None, async_op: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: @@ -1079,12 +1120,12 @@ def allreduce( # Bypass the function if we are using only 1 GPU. if get_distributed_world_size(tp_group) == 1: - return input_, None + return inp, None # All-reduce. - handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op) + handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op) - return input_, handle + return inp, handle def _fsdp_scatter_tensors( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7571b17c1f..1b62f8d777 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -56,6 +56,7 @@ restore_from_saved, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpp_extensions import ( @@ -326,6 +327,19 @@ def forward( clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + ctx.ln_out_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) + + # Input with column-wise usage is needed for dgrad GEMM. 
+ if backward_needs_input: + if isinstance(ln_out, QuantizedTensor): + # For sequence parallel in vanilla FP8, rowwise data is + # to gather the input. For MXFP8, columnwise only data + # can be allgathered. + if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: + ln_out.update_usage(rowwise_usage=False) + if cpu_offloading: if fp8 and weightmat is not None: set_offloading_param(weightmat, "weight_offloading", True) @@ -556,12 +570,7 @@ def backward( # Note: Perform tensor-parallel communication if needed ln_out_total = None ln_out_total_work = None - if ( - ctx.requires_wgrad - and ctx.parallel_mode == "column" - and ctx.sequence_parallel - and not ctx.ub_bulk_dgrad - ): + if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None if ctx.fp8: quantizer = ctx.input_quantizer diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 675a8f929b..4c87396e3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -56,6 +56,7 @@ prepare_for_saving, restore_from_saved, ) +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param @@ -155,6 +156,7 @@ def forward( ) if not isinstance(inputmat, QuantizedTensor): inputmat = input_quantizer(inputmat) + own_quantized_input = True elif backward_needs_input: inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) inputmat_total = inputmat @@ -251,9 +253,18 @@ def forward( if is_grad_enabled: saved_inputmat = None + + ctx.backward_input_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) + if backward_needs_input: if own_quantized_input and isinstance(inputmat, QuantizedTensor): - inputmat.update_usage(rowwise_usage=False) + # For sequence parallel in vanilla FP8, rowwise data is + # to gather the input. For MXFP8, columnwise only data + # can be allgathered. 
+ if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: + inputmat.update_usage(rowwise_usage=False) saved_inputmat = inputmat if cpu_offloading: @@ -311,7 +322,6 @@ def forward( ctx.requires_wgrad = weight.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - ctx.is_input_fp8 = not own_quantized_input if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -452,12 +462,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: Perform tensor-parallel communication if needed inputmat_total = None inputmat_total_work = None - if ( - ctx.requires_wgrad - and ctx.parallel_mode == "column" - and ctx.sequence_parallel - and not ctx.ub_bulk_dgrad - ): + if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None if ctx.fp8: quantizer = ctx.input_quantizer diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 892e120da1..cb93eb5e6b 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -523,9 +523,7 @@ def _functional_forward( # Configure input tensor for backward pass if own_quantized_x_local: - ### TODO Restore once column-wise usage is supported by itself # pylint: disable=fixme - # x_local.update_usage(rowwise_usage=False) - pass + x_local.update_usage(rowwise_usage=False) # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index bb01b1ee8b..bf518cae22 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ 
b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -5,6 +5,7 @@ """Mixin class holding data specific for Float8Tensor""" from __future__ import annotations +import math from typing import Any, Dict, Optional, Tuple import torch @@ -120,7 +121,10 @@ def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: def size(self, *args, **kwargs): # pylint: disable=missing-function-docstring - return self._data.size(*args, **kwargs) + if self._data is not None: + return self._data.size(*args, **kwargs) + size = self._transpose.size(*args, **kwargs) + return torch.Size([size[-1], math.prod(size[:-1])]) def __repr__(self): return ( diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index b818638d02..e6dcf1d48f 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -115,7 +115,9 @@ def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: def size(self, *args, **kwargs): # pylint: disable=missing-function-docstring - return self._rowwise_data.size(*args, **kwargs) + if self._rowwise_data is not None: + return self._rowwise_data.size(*args, **kwargs) + return self._columnwise_data.size(*args, **kwargs) def __repr__(self): data_rowwise = self.dequantize() diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 9b1e3f3dc4..5bea0398ab 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -428,21 +428,40 @@ def _create_transpose(self): self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose) self._transpose_invalid = False - def update_usage(self, rowwise_usage=True, columnwise_usage=True): - assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor" - 
if rowwise_usage: - assert self._data is not None, "Rowwise usage of the tensor was already disabled" - else: - if not non_tn_fp8_gemm_supported(): - if self._transpose is None or self._transpose_invalid: - self._create_transpose() - self._data = None - if columnwise_usage: - if self._transpose is None or self._transpose_invalid: - assert self._data is not None, "The tensor does not hold any data anymore" - if not non_tn_fp8_gemm_supported(): - self._create_transpose() + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + # Figure out what data is available and what is required + has_data = self._data is not None + has_data_transpose = self._transpose is not None and not self._transpose_invalid + needs_data = has_data + needs_data_transpose = has_data_transpose + if non_tn_fp8_gemm_supported(): + if rowwise_usage is not None and rowwise_usage: + needs_data = True + if columnwise_usage is not None and columnwise_usage: + needs_data = True + needs_data_transpose = False else: + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP8 data is required to generate FP8 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self._data = None + if not needs_data_transpose: self._transpose = None self._transpose_invalid = True diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6e3835fbef..843c7936f2 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -66,6 +66,16 @@ def update_quantized( return dst + 
def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: + return False + return True + def make_empty( self, shape: Iterable[int], @@ -207,36 +217,50 @@ def detach(self) -> MXFP8Tensor: # TODO(ksivamani): Fix the detach bug return MXFP8Tensor.make_like(self) - def update_usage(self, rowwise_usage=True, columnwise_usage=True): + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): """ For MXFP8, columnwise scaled output is only produced by x2 scaling kernels, so this function only disables usages. """ - assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor." - - if columnwise_usage and rowwise_usage: - assert ( - self._rowwise_data is not None - and self._rowwise_scale_inv is not None - and self._columnwise_data is not None - and self._columnwise_scale_inv is not None - ), "Cannot update to rowwise and columnwise usage." - return + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data if rowwise_usage: - assert ( - self._rowwise_data is not None and self._rowwise_scale_inv is not None - ), "Cannot update to rowwise usage." 
+ if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but MXFP8Tensor is missing column-scaled scale-inverses" + ) + else: self._columnwise_data = None self._columnwise_scale_inv = None - return - - assert ( - self._columnwise_data is not None and self._columnwise_scale_inv is not None - ), "Cannot update to columnwise usage." - self._rowwise_data = None - self._rowwise_scale_inv = None - return def clone(self) -> MXFP8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 00815452da..019aca9f60 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -21,13 +21,10 @@ def prepare_for_saving( """Prepare tensors for saving. 
Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save the internal TensorBase types too.""" - # pylint: disable=unidiomatic-typecheck # Using type instead of isinstance to check exact type + tensor_list, tensor_objects_list = [], [] for tensor in tensors: - if tensor is None: - tensor_list.append(None) - tensor_objects_list.append(None) - elif isinstance(tensor, torch.Tensor): + if tensor is None or isinstance(tensor, torch.Tensor): tensor_list.append(tensor) tensor_objects_list.append(None) else: @@ -44,7 +41,7 @@ def restore_from_saved( """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] for tensor in tensors: - if tensor is None: + if tensor is None or isinstance(tensor, torch.Tensor): tensor_objects.append(saved_tensors[0]) saved_tensors = saved_tensors[1:] else: @@ -289,7 +286,11 @@ def detach(self) -> QuantizedTensor: f"{self.__class__.__name__} class does not implement detach function" ) - def update_usage(self, rowwise_usage=True, columnwise_usage=True): + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): """Indicate to the tensor how it is going to be used This enables optimizations to memory usage in some cases From 09ffb5d99d0b6347a4af499f9e8b6bf339b0a43f Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Thu, 13 Mar 2025 14:05:05 -0700 Subject: [PATCH 202/239] [PyTorch] Support Bgrad Cast FP8 Fusion for FP8 Current Scaling Recipe (#1558) * add tex.bgrad_quantize support for CS Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unused import Signed-off-by: Tim Moon --------- Signed-off-by: zhongboz Signed-off-by: Tim Moon Co-authored-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon 
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../pytorch/csrc/extensions/bias.cpp | 23 +++++++++++++++++++ transformer_engine/pytorch/module/base.py | 4 ---- .../pytorch/module/layernorm_mlp.py | 9 +------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index a1fe8bd2b5..5ff10f6efb 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -7,6 +7,7 @@ #include "common.h" #include "pybind.h" #include "transformer_engine/cast.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine::pytorch { @@ -42,6 +43,28 @@ std::vector bgrad_quantize(const at::Tensor& input, py::handle py_qu workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); // Launch kernel + if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(quantizer.get()); + nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + 
nvte_compute_scale_from_amax(out_tensor.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape); + } nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a44e209d36..4b82054fec 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -35,7 +35,6 @@ from ..tensor import QuantizedTensor, Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer __all__ = ["initialize_ub", "destroy_ub"] @@ -860,9 +859,6 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) - elif isinstance(quantizer, Float8CurrentScalingQuantizer): - # FP8 current scaling does not support fused cast + dbias - grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9bb76cb391..1f167b5a7e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -797,14 +797,7 @@ def backward( ) # activation in high precision if ctx.fp8: - # TODO zhongboz: per-tensor current scaling has no bgrad fusion for now - if isinstance(ctx.grad_fc1_output_quantizer, 
Float8CurrentScalingQuantizer): - fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) - dact = ctx.grad_fc1_output_quantizer(dact) - else: - fc1_bias_grad, dact = tex.bgrad_quantize( - dact, ctx.grad_fc1_output_quantizer - ) + fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 From 62397359e55824e7df050b8566bdfb629e437e22 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Mar 2025 21:26:07 +0000 Subject: [PATCH 203/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4db8853102..78780c3ac9 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4910,7 +4910,16 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs, use_flash_attn_3] + args += [ + fp8, + fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, + ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) From b9ffe65b9c082a5c2114fef3911f2a9812b3e147 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 18:29:17 -0700 Subject: [PATCH 204/239] remove page_table from IP.step() returns Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 6 +----- transformer_engine/pytorch/inference.py | 18 ++++-------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git 
a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 78780c3ac9..032d10b0cb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6199,7 +6199,6 @@ def convert_to_torch_float8(tensor, dtype): QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] ) batch_size = cu_seqlens_q.shape[0] - 1 - num_heads_q = query_layer.shape[-2] num_heads_k = key_layer.shape[-2] fa_3_optional_forward_kwargs["q_descale"] = ( query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) @@ -7747,7 +7746,6 @@ def forward( max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) # update KV cache and retrieve saved tokens from cache for inference - page_table = None if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -7768,7 +7766,6 @@ def forward( ( key_layer, value_layer, - page_table, cu_seqlens_q, cu_seqlens_kv, max_seqlen_kv, @@ -8134,8 +8131,7 @@ def forward( alibi_slopes=alibi_slopes, inference_params=inference_params, ) - - return output + return None class MultiheadAttention(torch.nn.Module): diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index 2942cf90ca..dd7a9c15d7 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -3,7 +3,6 @@ # See LICENSE for license information. 
"""Inference""" -import os import logging from collections import OrderedDict, defaultdict from typing import Optional, List @@ -50,7 +49,7 @@ def step( qkv_format: str, # pylint: disable=unused-argument ): """Copy the new tokens to KV cache""" - return *self.cache[layer_number], None + return self.cache[layer_number] class InferenceParams: @@ -337,8 +336,6 @@ def step( Full key tensor containing both previous and current key tokens v_cache: torch.Tensor Full value tensor containing both previous and current value tokens - page_table: torch.Tensor - Page table for paged KV cache, [batch_size, max_pages_per_seq]. None for non-paged KV cache. cu_seqlens_q: torch.Tensor Updated cumulative sequence lengths for query, [batch_size + 1] cu_seqlens_kv: torch.Tensor @@ -356,7 +353,7 @@ def step( else: self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - k_cache, v_cache, page_table = self.cache_manager.step( + k_cache, v_cache = self.cache_manager.step( layer_number, new_k, new_v, @@ -368,7 +365,6 @@ def step( return ( k_cache, v_cache, - page_table, self.cu_seqlens_q, self.cu_seqlens_kv, self.max_seqlen_kv, @@ -507,8 +503,6 @@ def step( Full key tensor containing both previous and current key tokens v_cache: torch.Tensor Full value tensor containing both previous and current value tokens - page_table: torch.Tensor - None for non-paged KV cache """ k_cache, v_cache = self.cache[layer_number] @@ -540,7 +534,7 @@ def step( k_cache = k_cache[:batch_size] v_cache = v_cache[:batch_size] - return k_cache, v_cache, None + return k_cache, v_cache class Page: @@ -767,8 +761,6 @@ def step( Full key tensor containing both previous and current key tokens v_cache: torch.Tensor Full value tensor containing both previous and current value tokens - page_table: torch.Tensor - Page table for current iteration, in shape [batch_size, max_pages_per_seq] """ k_cache, v_cache = self.cache[layer_number] @@ -797,6 +789,4 @@ def step( False, ) - page_table = 
self.page_table[:batch_size] - - return k_cache, v_cache, page_table + return k_cache, v_cache From 430dd5649ac4d250f9de5bdeb320e1719ffeb688 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 18:30:28 -0700 Subject: [PATCH 205/239] fix FP8 FlashAttn DPA fp8_dpa tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 45 +++++++++++---------- transformer_engine/pytorch/attention.py | 5 +++ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 55523276a2..68b6e1bfbd 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -219,8 +219,7 @@ def get_model( qkv_format: str = "bshd", num_layers: int = 1, mode: str = "reference", - fp8_dpa: bool = False, - fp8_mha: bool = False, + is_fp8: bool = False, ): reset_rng_states() sigma = 0.023 @@ -238,13 +237,13 @@ def get_model( fp8_format=recipe.Format.HYBRID, amax_history_len=1, amax_compute_algo="most_recent", - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, + fp8_dpa=is_fp8, + fp8_mha=False, ) if module == "TransformerLayer": hidden_size = config.head_dim_qk * config.num_heads - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): + with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): model = [ TransformerLayer( hidden_size=hidden_size, @@ -258,6 +257,7 @@ def get_model( layer_number=layer_number, kv_channels=config.head_dim_qk, self_attn_mask_type=attn_mask_type, + fuse_qkv_params=False, params_dtype=dtype, attn_input_format=qkv_format, ) @@ -266,7 +266,7 @@ def get_model( for layer_number in range(1, num_layers + 1) ] if module == "DotProductAttention": - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): + with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): model = [ DotProductAttention( kv_channels=config.head_dim_qk, @@ 
-373,14 +373,14 @@ def generate_args( def get_tols(module, backend, dtype): if module == "TransformerLayer": tols = { - torch.half: 4e-3, - torch.bfloat16: 3.5e-2, + torch.half: (3e-3, 3e-3), + torch.bfloat16: (3e-2, 3e-2), } if module == "DotProductAttention": tols = { - torch.half: 1e-3, - torch.bfloat16: 1e-2, - torch.float8_e4m3fn: 3e-2, + torch.half: (1e-3, 1e-3), + torch.bfloat16: (1e-2, 1e-3), + torch.float8_e4m3fn: (1e-2, 3e-2), } return tols[dtype] @@ -394,6 +394,7 @@ def get_tols(module, backend, dtype): @pytest.mark.parametrize("is_cuda_graph", [False, True]) @pytest.mark.parametrize("is_fp8", [False, True]) def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8): + reset_rng_states() logger = logging.getLogger("test_paged_attn") fp8_recipe = recipe.DelayedScaling( margin=0, @@ -401,7 +402,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda amax_history_len=1, amax_compute_algo="most_recent", fp8_dpa=is_fp8, - fp8_mha=is_fp8, + fp8_mha=False, ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe @@ -480,8 +481,9 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") + # TransformerLayer FP8 TN Gemm currently requires %8=0 if is_fp8 and not ( - qkv_format == "thd" and module == "DotProductAttention" and dtype == torch.float16 + qkv_format == "thd" and module == "DotProductAttention" ): pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") @@ -516,8 +518,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda qkv_format, num_layers, mode="inference", - fp8_dpa=is_fp8, - fp8_mha=is_fp8, + is_fp8=is_fp8, ) # graph the model if necessary @@ -663,29 +664,29 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, 
module, is_cuda incremental_output = incremental_output[0] # compare results - tol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) + atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) for i, seq in enumerate(sim.t_seq_ids): token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[i, sim.step_lens[i] - 1, :], - atol=tol, - rtol=tol, + atol=atol, + rtol=rtol, ) if qkv_format == "sbhd": torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[sim.step_lens[i] - 1, i, :], - atol=tol, - rtol=tol, + atol=atol, + rtol=rtol, ) if qkv_format == "thd": torch.testing.assert_close( full_output[seq, sim.t_total_lens[i] - 1, :], incremental_output[cu_seqlens_q[i + 1] - 1, :], - atol=tol, - rtol=tol, + atol=atol, + rtol=rtol, ) sim.t += 1 diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 032d10b0cb..36bd05103b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -528,6 +528,11 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8_meta["recipe"].fp8_mha: + logger.debug("Disabling all backends for KV caching with FP8 MHA") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False if use_flash_attention_3 and q_format != "thd": if _flash_attn_3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") From 10e50f5d2528419035cb62d848cd0c80bae6a835 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Mar 2025 01:31:30 +0000 Subject: [PATCH 206/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
tests/pytorch/fused_attn/test_paged_attn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 68b6e1bfbd..d1c428c337 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -482,9 +482,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda if backend == "UnfusedAttention" and is_cuda_graph: pytest.skip("CUDA graph is not supported for UnfusedAttention backend") # TransformerLayer FP8 TN Gemm currently requires %8=0 - if is_fp8 and not ( - qkv_format == "thd" and module == "DotProductAttention" - ): + if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"): pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") # create full model From 7a8c0c518ac0b9ec8e24f9df4fcafb16b2d1656c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 20:19:09 -0700 Subject: [PATCH 207/239] fix CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 36bd05103b..141a8be33f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -141,11 +141,7 @@ def _get_supported_versions(version_min, version_max): try: _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( - "flash-attn v2 is not installed. 
To use, please install it by" - """ "pip3 install flash-attn".""", - ) + pass # only print warning if use_flash_attention_2 = True in get_attention_backend else: if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version: @@ -194,20 +190,21 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False -_flash_attn_3_installation_steps = """\ +_flash_attn_3_installation_steps_non_cp = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 (5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" +_flash_attn_3_installation_steps_cp = """\ +(1) pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" +(2) python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` +(3) mkdir -p $python_path/flashattn_hopper +(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" try: _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( - "flash-attn v3 is not installed. 
To use, please install it by \n%s", - _flash_attn_3_installation_steps, - ) + pass # only print warning if use_flash_attention_3 = True in get_attention_backend else: from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flash_attn_3.flash_attn_interface import ( @@ -947,12 +944,12 @@ def get_attention_backend( logger.warning( "flash-attn v3 may provide important feature support or performance improvement." " Please install flash-attn v3 by \n%s", - _flash_attn_3_installation_steps, + _flash_attn_3_installation_steps_cp if context_parallel else _flash_attn_3_installation_steps_non_cp, ) elif use_flash_attention_2 and not _flash_attn_is_installed: logger.warning( "flash-attn may provide important feature support or performance improvement." - " Please install flash-attn %s.", + " Please install flash-attn %s by pip3 install flash-attn==.", _get_supported_versions( _flash_attn_version_required, _flash_attn_max_version, @@ -3802,6 +3799,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4290,6 +4288,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4816,6 +4815,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -6236,7 +6236,7 @@ def convert_to_torch_float8(tensor, dtype): e.args[0] + ". Please update your flash-attn v3 (beta) installation as it " + "may have added more supported arguments to its API. 
\n" - + _flash_attn_3_installation_steps, + + _flash_attn_3_installation_steps_non_cp, ) + e.args[1:] raise From a3bc1b4e8bbb713eaf946e65bee0544340ba9b5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Mar 2025 03:19:56 +0000 Subject: [PATCH 208/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 141a8be33f..cb28ec08bd 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -944,7 +944,11 @@ def get_attention_backend( logger.warning( "flash-attn v3 may provide important feature support or performance improvement." " Please install flash-attn v3 by \n%s", - _flash_attn_3_installation_steps_cp if context_parallel else _flash_attn_3_installation_steps_non_cp, + ( + _flash_attn_3_installation_steps_cp + if context_parallel + else _flash_attn_3_installation_steps_non_cp + ), ) elif use_flash_attention_2 and not _flash_attn_is_installed: logger.warning( From 0fd197f51de2ebcf980f62b4a7bdcd40fad35001 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 21:05:18 -0700 Subject: [PATCH 209/239] minor tweaks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 8 +++--- tests/pytorch/test_numerics.py | 4 ++- transformer_engine/pytorch/graph.py | 13 +++++---- transformer_engine/pytorch/inference.py | 30 +++++++++++---------- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index d1c428c337..b19fb238d9 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ 
b/tests/pytorch/fused_attn/test_paged_attn.py @@ -49,11 +49,13 @@ param_types.append(torch.bfloat16) model_configs_infer = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: b, h, hg, d, sq, skv, p, mask, bias "infer_0": ModelConfig( 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 ), - # "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16), + "infer_1": ModelConfig( + 2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 + ), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -374,7 +376,7 @@ def get_tols(module, backend, dtype): if module == "TransformerLayer": tols = { torch.half: (3e-3, 3e-3), - torch.bfloat16: (3e-2, 3e-2), + torch.bfloat16: (3.5e-2, 3.5e-2), } if module == "DotProductAttention": tols = { diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d4b18ee5f3..93230fa0f3 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2038,7 +2038,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", [False]) # all_boolean) +@pytest.mark.parametrize("use_RoPE", all_boolean) @pytest.mark.parametrize("input_format", input_formats_inference) @pytest.mark.parametrize("module", module_inference) @pytest.mark.parametrize("backend", backends_inference) @@ -2048,6 +2048,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: pytest.skip("FusedAttention and FlashAttention do not support FP32") + if use_RoPE: + pytest.skip("KV cache does not support starting positions for RoPE") os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" diff --git 
a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 4637a8456c..fc466ed267 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -91,6 +91,9 @@ def _make_graphed_callables( sample_args = (sample_args,) sample_kwargs = (sample_kwargs,) + # Check training/inference + is_training = all(c.training for c in callables) + # Check sizes of args if _order is None: assert len(sample_args) == len(callables) @@ -255,7 +258,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: hook.remove() - if callables[0].training: + if is_training: grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -317,7 +320,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) - if callables[0].training: + if is_training: with torch.cuda.graph(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), @@ -334,7 +337,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument grad_idx = 0 for arg in static_input_surface: if ( - callables[0].training + is_training and isinstance(arg, torch.Tensor) and arg.requires_grad ): @@ -374,7 +377,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) - if callables[0].training: + if is_training: with torch.cuda.graph(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), @@ -390,7 +393,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_grad_inputs = [] grad_idx = 0 for arg in 
static_input_surface: - if callables[0].training and isinstance(arg, torch.Tensor) and arg.requires_grad: + if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py index dd7a9c15d7..46e961b381 100644 --- a/transformer_engine/pytorch/inference.py +++ b/transformer_engine/pytorch/inference.py @@ -54,8 +54,8 @@ def step( class InferenceParams: """ - KV caching mechanism in inference. The memory allocation of the caches, and the copying of - new tokens to the cache take place at the following locations in TransformerLayer.:: + KV caching for inference. The memory allocation of the caches and the copying of new tokens + to the cache take place at the following locations.:: class TransformerLayer: class MultiHeadAttention: @@ -67,20 +67,22 @@ class DotProductAttention: new_k, new_v, qkv_format) output = attention(new_q, k_cache, v_cache, new_qkv_format) - allocate_memory() can be called independently if needed. step() takes 'bshd', 'sbhd' and 'thd' - formats and converts new_k and new_v to 'bshd' in both NonPagedKVCacheManager and PagedKVCacheManager. - Since new_q's format is unchanged, the returned new_qkv_format is 'bshd', 'sbhd_2bshd' and 'thd_2bshd', - respectively. A standard workflow for using InferenceParams to cache KV tokens, is as follows.:: + allocate_memory() can be called outside the model, independently. step() can take three formats, + qkv_format = {'bshd', 'sbhd', 'thd'}. It converts new_k and new_v to 'bshd' in both + NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on the + backend. If it is unchanged, we would have new_qkv_format = {'bshd', 'sbhd_2bshd', 'thd_2bshd'}. 
+ A standard KV caching workflow for inference is as follows.:: model = [TransformerLayer() for _ in range(num_layers)] - # initialize InferenceParams, for example, with PagedKVCacheManager + # initialize InferenceParams, e.g. with PagedKVCacheManager inference_params = InferenceParams(..., is_paged=True) - # inference iterations + # inference loop for i in range(num_iters): - # get step info, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1] + # get info for iteration i, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1] step_dict = OrderedDict(zip(seq_ids, step_lens)) - # update inference_params state + # update inference_params' state inference_params.pre_step(step_dict) + # run iteration output = model( ..., attn_mask_type="padding_causal", @@ -88,10 +90,10 @@ class DotProductAttention: cu_seqlens_kv=cu_seqlens_new_kv, inference_params=inference_params, ) - # get inference tokens based on qkv_format - # "bshd": output = output[:,step_dict.values()-1] - # "sbhd": output = output[step_dict.values()-1,:] - # "thd" : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1 + # get output tokens based on qkv_format + # 'bshd': output = output[:,step_dict.values()-1] + # 'sbhd': output = output[step_dict.values()-1,:] + # 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1 Parameters From e284346e6b914bedf2e645366e2bfe61ca4e8d80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Mar 2025 04:06:02 +0000 Subject: [PATCH 210/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index fc466ed267..827d43196d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -336,11 +336,7 @@ def hook_fn(module, inputs, outputs): # pylint: 
disable=unused-argument static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if ( - is_training - and isinstance(arg, torch.Tensor) - and arg.requires_grad - ): + if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: From 7a9f357c1d70d6aa6d70993db226bf9f3e7b4f88 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 21:14:29 -0700 Subject: [PATCH 211/239] update FA3 note and L3 test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L3_pytorch_FA_versions_test/test.sh | 10 ++++++---- transformer_engine/pytorch/attention.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index f57d055db5..febc1aa1ad 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -28,10 +28,12 @@ do then pip3 install flash-attn==${fa_version} else - pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" - python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` - mkdir -p $python_path/flashattn_hopper - wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install + python_path=`python -c "import site; print(site.getsitepackages()[0])"` + mkdir -p $python_path/flash_attn_3 + wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py + cd ../../ fi # Run tests diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 
cb28ec08bd..8bebd18afd 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -190,6 +190,8 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_is_installed = False _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False +# FA3 from FA 2.7.3+/hopper has different APIs than from 2.7.2/hopper +# TODO: adopt these new APIs for CP _flash_attn_3_installation_steps_non_cp = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install From 2495c801c922f51e9219c18327b570deb6ed7a67 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 13 Mar 2025 21:25:14 -0700 Subject: [PATCH 212/239] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8bebd18afd..bb79b27c52 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -191,7 +191,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False # FA3 from FA 2.7.3+/hopper has different APIs than from 2.7.2/hopper -# TODO: adopt these new APIs for CP +# we need to adopt these new APIs for CP _flash_attn_3_installation_steps_non_cp = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install From 28d9983a17f94a461fed88ce875f4918a765bb55 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 08:47:55 -0700 Subject: [PATCH 213/239] remove redundant import in test Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index b19fb238d9..129463b871 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -34,7 +34,6 @@ reset_rng_states, _get_attention_backends, ) -from tests.pytorch.test_numerics import assert_allclose # Initialize RNG state seed = 1234 From dc40f9fdae11e429d8f5420681da147a0859899e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 15 Mar 2025 01:38:19 +0800 Subject: [PATCH 214/239] Update FE to 1.11 (#1580) update FE to 1.11 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 91b7532f33..20c28ea798 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a +Subproject commit 20c28ea798fe99e31d7274e009ee2fbf0e88abfd From 12c3e323773d639319f917d07c4561e2424010fc Mon Sep 17 00:00:00 2001 From: hx Date: Fri, 14 Mar 2025 10:56:56 -0700 Subject: [PATCH 215/239] Fix import error on CPU only devices (#1578) fix cpu device import error Signed-off-by: Hongxiao Bai Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../pytorch/triton/permutation.py | 129 +++++++++++------- 1 file changed, 77 insertions(+), 52 deletions(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 4ed92b0c80..1c5fd73581 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -109,16 +109,6 @@ def make_row_id_map( return row_id_map -@triton.autotune( - configs=[ - 
triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - ], - key=["hidden_size"], -) @triton.jit def _permute_kernel( # pointers @@ -164,6 +154,21 @@ def _permute_kernel( cur_pos += BLOCK_SIZE +try: + _permute_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_permute_kernel) +except RuntimeError: + pass + + def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -201,16 +206,6 @@ def permute_with_mask_map( return output, permuted_probs -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - ], - key=["hidden_size"], -) @triton.jit def _unpermute_kernel( # pointers @@ -297,6 +292,21 @@ def _unpermute_kernel( current_start += BLOCK_SIZE +try: + _unpermute_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_unpermute_kernel) +except RuntimeError: + pass + + def unpermute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -348,16 +358,6 @@ def unpermute_with_mask_map( return output, unpermuted_probs -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - ], - key=["hidden_size"], -) @triton.jit def _unpermute_bwd_with_merging_probs_kernel( # pointers @@ -450,6 +450,21 
@@ def _unpermute_bwd_with_merging_probs_kernel( tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) +try: + _unpermute_bwd_with_merging_probs_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_unpermute_bwd_with_merging_probs_kernel) +except RuntimeError: + pass + + def unpermute_with_mask_map_bwd_with_merging_probs( fwd_output_grad: torch.Tensor, row_id_map: torch.Tensor, @@ -500,16 +515,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( return act_grad, merging_probs_grad -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - ], - key=["hidden_size"], -) @triton.jit def _sort_chunks_by_idxs_kernel( # pointers @@ -589,6 +594,21 @@ def _sort_chunks_by_idxs_kernel( tl.store(permuted_probs_ptr + permuted_prob_off, prob) +try: + _sort_chunks_by_idxs_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_sort_chunks_by_idxs_kernel) +except RuntimeError: + pass + + def sort_chunks_by_idx( inp: torch.Tensor, split_sizes: torch.Tensor, @@ -628,18 +648,8 @@ def sort_chunks_by_idx( return output, row_id_map, permuted_probs -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - ], - key=["hidden_size"], -) @triton.jit -def _sort_chunks_by_map( +def _sort_chunks_by_map_kernel( # pointers input_ptr, output_ptr, @@ -677,6 +687,21 
@@ def _sort_chunks_by_map( tl.store(permuted_probs_ptr + permuted_prob_off, prob) +try: + _sort_chunks_by_map_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_sort_chunks_by_map_kernel) +except RuntimeError: + pass + + def sort_chunks_by_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -691,7 +716,7 @@ def sort_chunks_by_map( else: permuted_probs = None grid = (num_tokens,) - _sort_chunks_by_map[grid]( + _sort_chunks_by_map_kernel[grid]( inp, output, row_id_map, From c257bf316a96561225fdabe6e80a4d2fd74fa317 Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Fri, 14 Mar 2025 13:29:21 -0700 Subject: [PATCH 216/239] Blackwell devel commoverlap mlperftests (#1529) * Add options to comm overlap tests Signed-off-by: Vasudevan Rengasamy * Fix Typo Signed-off-by: Vasudevan Rengasamy * Update tests/pytorch/distributed/run_layer_with_overlap.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- setup.py | 2 +- .../distributed/run_layer_with_overlap.py | 167 ++++++++++++++---- 2 files changed, 136 insertions(+), 33 deletions(-) diff --git a/setup.py b/setup.py index 996027bd9e..13e8b6ee83 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: install_reqs.extend(["torch>=2.1"]) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") - test_reqs.extend(["numpy", "torchvision", "prettytable"]) + test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) if "jax" in frameworks: 
install_reqs.extend(["jax", "flax>=0.7.1"]) # test_reqs.extend(["numpy", "praxis"]) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 39200775c9..3526ad812f 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -11,6 +11,7 @@ import argparse import warnings import pprint +import yaml import torch import torch.distributed as dist @@ -46,7 +47,7 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): input_shape = [config.seq_length, config.batch_size, hidden_size] args = [hidden_size] kwargs = { - "params_dtype": torch.float32, + "params_dtype": torch.float32 if not config.use_bf16_params else torch.bfloat16, "device": "cuda", "tp_group": tp_group, "tp_size": tp_size, @@ -59,11 +60,18 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): if config.linear_parallel_mode == "row": input_shape[-1] = ffn_hidden_size // tp_size args = [ffn_hidden_size, hidden_size] + if config.in_features is not None: + input_shape[-1] = config.in_features // tp_size + args = [config.in_features, hidden_size] kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2" + kwargs["ub_name"] = kwargs["ub_name"] if config.ub_name is None else config.ub_name elif config.linear_parallel_mode == "column": input_shape[0] = config.seq_length // tp_size - args.append(qkv_size) - kwargs["ub_name"] = "qkv" + if config.out_features is not None: + args.append(config.out_features) + else: + args.append(qkv_size) + kwargs["ub_name"] = "qkv" if config.ub_name is None else config.ub_name kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference @@ -87,6 +95,9 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): kwargs["ub_bulk_dgrad"] = not 
config.overlap_rs_dgrad and not reference kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + if config.ub_cfg is not None and isinstance(config.ub_cfg, str): + with open(config.ub_cfg, "r") as stream: + config.ub_cfg = yaml.safe_load(stream) return args, kwargs, input_shape @@ -103,6 +114,30 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." ) + parser.add_argument( + "--in-features", + type=int, + default=None, + help="Optional input feature size for weight. Only used for Linear layer.", + ) + parser.add_argument( + "--out-features", + type=int, + default=None, + help="Optional output feature size for weight. Only used for LayerNormLinear layer.", + ) + parser.add_argument( + "--tp", + type=int, + default=None, + help="Optional tensor_model_parallel_size used to initialize UB.", + ) + parser.add_argument( + "--use-bf16-params", + action="store_true", + default=False, + help="Use BF16 params instead of FP32.", + ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." @@ -132,6 +167,28 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." ) + parser.add_argument( + "--ub-cfg", type=str, default=None, help="Optional TP config yaml file input." 
+ ) + parser.add_argument("--ub-name", type=str, default=None, help="Optional TP layer name.") + parser.add_argument( + "--skip-verify", + action="store_true", + default=False, + help="Skip numerics check.", + ) + parser.add_argument( + "--benchmark", + action="store_true", + default=False, + help="Benchmark comm-gemm overlap perf.", + ) + parser.add_argument( + "--benchmark-iter", + type=int, + default=100, + help="Number of iterations for benchmarking perf.", + ) parser.add_argument( "--linear-parallel-mode", type=str.lower, @@ -223,9 +280,36 @@ def _train(opts): shell=True, ) - if result.stdout == "0": # Extra checks for non-MNNVL platforms + if result.stdout == "0" and opts.tp is None: # Extra checks for non-MNNVL platforms assert WORLD_SIZE == LOCAL_SIZE + # Initialize torch.distributed tp process group + new_group_kwargs = { + "backend": "nccl", + } + if opts.tp is not None: + LOCAL_SIZE = opts.tp + tp_base_rank = (WORLD_RANK // LOCAL_SIZE) * LOCAL_SIZE + tp_rank_list = list(range(tp_base_rank, tp_base_rank + LOCAL_SIZE)) + new_group_kwargs = { + "backend": "nccl", + "ranks": tp_rank_list, + } + else: + opts.tp = WORLD_SIZE + + # Tensor dim overrides for tensors that do not require TP communication + if opts.in_features is not None: + assert opts.layer_type is te.Linear and opts.linear_parallel_mode == "row", ( + "--in-features is only used to configure row-tensor-parallel Linear layers. Use" + " --num-heads or --head-dim for other cases." + ) + if opts.out_features is not None: + assert opts.layer_type is te.LayerNormLinear and opts.linear_parallel_mode == "column", ( + "--out-features is only used to configure column-tensor-parallel LayerNormLinear" + " layers. Use --num-heads or --head-dim for other cases." 
+ ) + def dist_print(msg, src=None, end="\n", debug=False, error=False): if debug and not opts.debug: return @@ -253,9 +337,11 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") assert dist.is_nccl_available() dist.init_process_group(**dist_init_kwargs) - nccl_world = dist.new_group(backend="nccl") + nccl_world = dist.new_group(**new_group_kwargs) dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + # Initialize the Transformer Engine layer with overlap + args, kwargs, input_shape = _get_layer_args(opts, nccl_world, opts.tp) # Intialize userbuffers ub_cfgs = None if opts.overlap_rs_dgrad: @@ -265,15 +351,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], - WORLD_SIZE, + opts.tp, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, - ub_cfgs=ub_cfgs, + ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, ) - # Initialize the Transformer Engine layer with overlap - args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE) with te.fp8_model_init(enabled=opts.fp8_init): test_model = opts.layer_type(*args, **kwargs) dist_print("Initialized test model...", debug=True) @@ -283,7 +367,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist.barrier() # Initialize the reference model and copy all parameters - ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) + ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, opts.tp, reference=True) with te.fp8_model_init(enabled=opts.fp8_init): ref_model = opts.layer_type(*ref_args, **ref_kwargs) dist_print("Initialized reference model...", debug=True) @@ -326,7 +410,8 @@ def run_fwd_bwd(model, x): with torch.cuda.graph(test_graph): test_out = run_fwd_bwd(test_model, test_x) 
test_graph.replay() - del test_graph + if not opts.benchmark: + del test_graph else: test_out = run_fwd_bwd(test_model, test_x) test_grads = [test_out, test_x.grad] @@ -351,28 +436,46 @@ def run_fwd_bwd(model, x): if ref_param.requires_grad and "layer_norm" not in ref_name: ref_grads.append(ref_param.grad) - # Make sure we have the same number of gradients numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") - if len(test_grads) != len(ref_grads): - numerics_failed[0] = 1 - numerics_info = ( - "NUMERICAL CHECK FAILED: Incorrect number of gradients, " - + f"expected {len(ref_grads)} but got {len(test_grads)}." - ) - dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - - # Now validate accuracy - if not bool(numerics_failed.item()): - for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 - grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) - dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[0] = int(grad_failed) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()) and not opts.debug: - break + if not opts.skip_verify: + # Make sure we have the same number of gradients + if len(test_grads) != len(ref_grads): + numerics_failed[0] = 1 + numerics_info = ( + "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + + f"expected {len(ref_grads)} but got {len(test_grads)}." 
+ ) + dist_print(numerics_info, src=WORLD_RANK, error=True) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + + # Now validate accuracy + if not bool(numerics_failed.item()): + for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): + rtol = 0.125 if opts.fp8 else 0.025 + atol = 0.0625 if opts.fp8 else 0.00125 + grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) + dist_print(grad_info, src=WORLD_RANK, error=grad_failed) + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + if bool(numerics_failed.item()) and not opts.debug: + break + + if opts.benchmark: + # Warmup to not profile CPU overhead + for _ in range(100): + if opts.use_cuda_graphs: + test_graph.replay() + else: + test_out = run_fwd_bwd(test_model, test_x) + torch.cuda.cudart().cudaProfilerStart() + for _ in range(opts.benchmark_iter): + if opts.use_cuda_graphs: + test_graph.replay() + else: + test_out = run_fwd_bwd(test_model, test_x) + torch.cuda.cudart().cudaProfilerStop() + if opts.use_cuda_graphs: + del test_graph te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) From 496776b2a587018b558a74673fafc718fcf3ba21 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 15:02:39 -0700 Subject: [PATCH 217/239] adopt new FA3 APIs from FA2.7.3+/hopper for CP and non-CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 397 ++++++++++++++---------- 1 file changed, 240 insertions(+), 157 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bb79b27c52..ce41722e06 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -190,19 +190,14 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_is_installed = False 
_flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False -# FA3 from FA 2.7.3+/hopper has different APIs than from 2.7.2/hopper -# we need to adopt these new APIs for CP -_flash_attn_3_installation_steps_non_cp = """\ +# FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper +# Please follow these instructions to install FA3 +_flash_attn_3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 (5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" -_flash_attn_3_installation_steps_cp = """\ -(1) pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" -(2) python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` -(3) mkdir -p $python_path/flashattn_hopper -(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" try: _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: @@ -946,11 +941,7 @@ def get_attention_backend( logger.warning( "flash-attn v3 may provide important feature support or performance improvement." 
" Please install flash-attn v3 by \n%s", - ( - _flash_attn_3_installation_steps_cp - if context_parallel - else _flash_attn_3_installation_steps_non_cp - ), + _flash_attn_3_installation_steps, ) elif use_flash_attention_2 and not _flash_attn_is_installed: logger.warning( @@ -1903,6 +1894,92 @@ def _get_cu_seqlens_info_with_cp( return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] +def get_fa_args( + forward: bool, + use_flash_attn_3: bool, + qkv_format: str, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + dq=None, + dk=None, + dv=None, + ): + if use_flash_attn_3: + if forward: + if qkv_format == "thd": + return [ + *[None] * 4, # k_new, v_new, qv, out + cu_seqlens_q, + cu_seqlens_kv, + *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + else: + return [ + *[None] * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + else: + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + else: + return [ + None, # cu_seqlens_q + None, # cu_seqlens_kv + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + else: + if forward: + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + else: + return [] + else: + if qkv_format == "thd": + return [ + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + else: + return [ + dq, + dk, + dv, + ] + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. 
Exchange KV between CP ranks @@ -2119,7 +2196,7 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or use_flash_attn_3: + if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 @@ -2272,14 +2349,15 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - fa_forward_args_thd = [] - if qkv_format == "thd": - fa_forward_args_thd = [ - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, - ] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -2402,14 +2480,15 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - fa_forward_args_thd = [] - if qkv_format == "thd": - fa_forward_args_thd = [ - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv // 2, - ] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv // 2, + ) if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -2548,14 +2627,15 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - fa_forward_args_thd = [] - if qkv_format == "thd": - fa_forward_args_thd = [ - cu_seqlens_q_per_step[i], - 
cu_seqlens_kv_per_step[i], - max_seqlen_q // 2, - max_seqlen_kv, - ] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q // 2, + max_seqlen_kv=max_seqlen_kv, + ) if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -2670,14 +2750,15 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - fa_forward_args_thd = [] - if qkv_format == "thd": - fa_forward_args_thd = [ - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, - ] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) fa_outputs = flash_attn_fwd( q, ( @@ -2903,7 +2984,6 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") - use_flash_attn_3 = ctx.use_flash_attn_3 cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a @@ -3068,7 +3148,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if use_flash_attn_3: + if ctx.use_flash_attn_3: flash_attn_bwd = ( _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment ) @@ -3209,22 +3289,26 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - fa_backward_args_thd = [] - if ctx.qkv_format == "thd": - fa_backward_args_thd = [ - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ] - if use_flash_attn_3 or ( + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + 
cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_backward_kwargs["window_size"] = (-1, 0) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 - if not use_flash_attn_3: + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3233,9 +3317,6 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, - dq_, - dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], *fa_backward_args_thd, causal=True, **fa_backward_kwargs, @@ -3326,22 +3407,26 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - fa_backward_args_thd = [] - if ctx.qkv_format == "thd": - fa_backward_args_thd = [ - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, - ] - if use_flash_attn_3 or ( + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv // 2, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_backward_kwargs["window_size"] = 
(-1, -1) - if _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not use_flash_attn_3: + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3350,9 +3435,6 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, - dq_, - dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], *fa_backward_args_thd, causal=False, **fa_backward_kwargs, @@ -3445,22 +3527,26 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - fa_backward_args_thd = [] - if ctx.qkv_format == "thd": - fa_backward_args_thd = [ - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, - ] - if use_flash_attn_3 or ( + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q // 2, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not use_flash_attn_3: + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3469,9 +3555,6 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse_, - dq_, - dkv_[..., 0, :, 
:] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], *fa_backward_args_thd, causal=False, **fa_backward_kwargs, @@ -3541,20 +3624,24 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q) dkv_ = torch.empty_like(kv) - fa_backward_args_thd = [] - if ctx.qkv_format == "thd": - fa_backward_args_thd = [ - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ] - if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not use_flash_attn_3: + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout, @@ -3563,9 +3650,6 @@ def backward(ctx, dout): kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], out, softmax_lse, - dq_, - dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], *fa_backward_args_thd, causal=False, **fa_backward_kwargs, @@ -4006,14 +4090,15 @@ def forward( window_size=window_size_per_step[i], ) else: - fa_forward_args_thd = [] - if qkv_format == "thd": - fa_forward_args_thd = [ - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv_, 
- ] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv_, + ) if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -4090,7 +4175,6 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") - use_flash_attn_3 = ctx.use_flash_attn_3 cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) @@ -4139,7 +4223,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if use_flash_attn_3: + if ctx.use_flash_attn_3: flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: @@ -4201,19 +4285,23 @@ def backward(ctx, dout): dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] ] - fa_backward_args_thd = [] - if ctx.qkv_format == "thd": - fa_backward_args_thd = [ - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - ctx.max_seqlen_q, - max_seqlen_kv, - ] - if not use_flash_attn_3: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + dq=dq_per_step[i], + dk=dk_per_step[i], + dv=dv_per_step[i], + ) + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] - if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: + if ctx.use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = window_size_per_step[i] - if _flash_attn_2_7_0_plus: + elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] 
fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( @@ -4223,9 +4311,6 @@ def backward(ctx, dout): v_, out_, softmax_lse_per_step[i], - dq_per_step[i], - dk_per_step[i], - dv_per_step[i], *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, **fa_backward_kwargs, @@ -4294,7 +4379,6 @@ def backward(ctx, dout): None, None, None, - None, ) @@ -4366,7 +4450,7 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size"] = window_size elif _flash_attn_2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size[0] @@ -4471,14 +4555,15 @@ def forward( if fp8: out = out._data else: - fa_forward_args_thd = [] - if qkv_format == "thd": - fa_forward_args_thd = [ - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) fa_outputs = flash_attn_fwd( q, k, @@ -4583,7 +4668,6 @@ def forward( def backward(ctx, dout): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") - use_flash_attn_3 = ctx.use_flash_attn_3 cp_size = get_distributed_world_size(ctx.cp_group) ( @@ -4665,7 +4749,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if use_flash_attn_3: + if ctx.use_flash_attn_3: flash_attn_bwd = ( _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment ) @@ -4677,7 +4761,7 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if use_flash_attn_3 or 
(_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size"] = ctx.window_size elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = ctx.window_size[0] @@ -4745,15 +4829,19 @@ def backward(ctx, dout): else: softmax_lse, rng_state = aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] - fa_backward_args_thd = [] - if ctx.qkv_format == "thd": - fa_backward_args_thd = [ - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ] - if not use_flash_attn_3: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq, + dk=dk, + dv=dv, + ) + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( dout, @@ -4762,9 +4850,6 @@ def backward(ctx, dout): v, out, softmax_lse, - dq, - dk, - dv, *fa_backward_args_thd, causal=causal, **fa_backward_kwargs, @@ -4820,8 +4905,6 @@ def backward(ctx, dout): None, None, None, - None, - None, ) @@ -6242,7 +6325,7 @@ def convert_to_torch_float8(tensor, dtype): e.args[0] + ". Please update your flash-attn v3 (beta) installation as it " + "may have added more supported arguments to its API. 
\n" - + _flash_attn_3_installation_steps_non_cp, + + _flash_attn_3_installation_steps, ) + e.args[1:] raise @@ -6923,12 +7006,12 @@ def forward( if qkv_format in ["sbhd", "bshd"]: if qkv_format == "sbhd": batch_size = query_layer.shape[1] - max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + max_seqlen_q = query_layer.shape[0] + max_seqlen_kv = key_layer.shape[0] if qkv_format == "bshd": batch_size = query_layer.shape[0] - max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv + max_seqlen_q = query_layer.shape[1] + max_seqlen_kv = key_layer.shape[1] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size if "padding" in attn_mask_type: From 7f1c7655f933c37ceebd38c326c3d661cdfcd704 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 15:10:23 -0700 Subject: [PATCH 218/239] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 114 +++++++++++------------- 1 file changed, 54 insertions(+), 60 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ce41722e06..fe08e7477c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1906,6 +1906,7 @@ def get_fa_args( dk=None, dv=None, ): + """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: if forward: if qkv_format == "thd": @@ -1918,66 +1919,59 @@ def get_fa_args( max_seqlen_kv, *[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale ] - else: - return [ - *[None] * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k - max_seqlen_q, - max_seqlen_kv, - *[None] * 
8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale - ] - else: - if qkv_format == "thd": - return [ - cu_seqlens_q, - cu_seqlens_kv, - None, # sequed_q - None, # sequed_k - max_seqlen_q, - max_seqlen_kv, - dq, - dk, - dv, - ] - else: - return [ - None, # cu_seqlens_q - None, # cu_seqlens_kv - None, # sequed_q - None, # sequed_k - max_seqlen_q, - max_seqlen_kv, - dq, - dk, - dv, - ] - else: - if forward: - if qkv_format == "thd": - return [ - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ] - else: - return [] - else: - if qkv_format == "thd": - return [ - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ] - else: - return [ - dq, - dk, - dv, - ] + return [ + *[None] * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + return [ + None, # cu_seqlens_q + None, # cu_seqlens_kv + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + if forward: + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [] + if qkv_format == "thd": + return [ + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [ + dq, + dk, + dv, + ] class AttnFuncWithCPAndKVP2P(torch.autograd.Function): From 0cf5c0d8222c899f5ac8aee6f05d4fb136c9b817 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Mar 2025 22:11:23 +0000 Subject: [PATCH 219/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci 
--- transformer_engine/pytorch/attention.py | 79 +++++++++++++++++-------- 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fe08e7477c..da55b73c5d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1905,7 +1905,7 @@ def get_fa_args( dq=None, dk=None, dv=None, - ): +): """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: if forward: @@ -1917,13 +1917,16 @@ def get_fa_args( *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k max_seqlen_q, max_seqlen_kv, - *[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + *[None] + * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale ] return [ - *[None] * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k + *[None] + * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k max_seqlen_q, max_seqlen_kv, - *[None] * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + *[None] + * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale ] if qkv_format == "thd": return [ @@ -2351,7 +2354,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - ) + ) fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -2482,7 +2485,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv // 2, - ) + ) if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -2629,7 +2632,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q // 2, max_seqlen_kv=max_seqlen_kv, - ) + ) if use_flash_attn_3 or ( _flash_attn_2_3_plus and not 
_flash_attn_2_7_0_plus ): @@ -2752,7 +2755,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - ) + ) fa_outputs = flash_attn_fwd( q, ( @@ -3292,9 +3295,17 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], - ) + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) if ctx.use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -3410,9 +3421,17 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv // 2, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], - ) + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) if ctx.use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -3530,9 +3549,17 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q // 2, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], - ) + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) if ctx.use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -3629,8 +3656,10 @@ def backward(ctx, dout): dq=dq_, dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], dv=dkv_[..., 1, :, :] if ctx.qkv_format 
in ["bshd", "sbhd"] else dkv_[1], - ) - if ctx.use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + ) + if ctx.use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -4092,7 +4121,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, - ) + ) if use_flash_attn_3 or ( _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus ): @@ -4290,10 +4319,12 @@ def backward(ctx, dout): dq=dq_per_step[i], dk=dk_per_step[i], dv=dv_per_step[i], - ) + ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] - if ctx.use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if ctx.use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = window_size_per_step[i] elif _flash_attn_2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -4557,7 +4588,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - ) + ) fa_outputs = flash_attn_fwd( q, k, @@ -4834,7 +4865,7 @@ def backward(ctx, dout): dq=dq, dk=dk, dv=dv, - ) + ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( From 5578b699a684c6775df8f6cf65d28ae6fc1a25c9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 16:12:47 -0700 Subject: [PATCH 220/239] relax tols for TransformerLayers Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 129463b871..d23b9da897 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py 
+++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -374,7 +374,7 @@ def generate_args( def get_tols(module, backend, dtype): if module == "TransformerLayer": tols = { - torch.half: (3e-3, 3e-3), + torch.half: (4e-3, 4e-3), torch.bfloat16: (3.5e-2, 3.5e-2), } if module == "DotProductAttention": From 373394789c5861d40c9ec51a2c2c2cca22b22ba0 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Fri, 14 Mar 2025 16:53:23 -0700 Subject: [PATCH 221/239] Refactoring attention.py part 1 (#1542) * Create pytorch/dot_product_attention module and pytorch/d_p_a/utils.py Move attention logging into a separate class in pytorch/d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani * Create FlashAttentionUtils class in pytorch/d_p_a/utils/py for versioning info Move versioning info out of pytorch/attention.py Signed-off-by: Kshitij Janardan Lakhani * Move AttentionParams and get_attention_backend from attention.py to d_p_a/utils.py Fix tests and imports for the above refactor change Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move get_qkv_layout(), get_full_mask(), get_alibi(), get_attention_quantizers() to d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move tensor packing and unpacking helper functions from pyt/attention.py to d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move cumulative seqlens and indices methods from pyt/attention.py to d_p_a/utils.py Rename cumulative functions from using _cu_ to using _cumul_ to differentiate from CUDA cu calls protocol Rename tensor packaging methods with leading underscore to make them as internal to file Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary imports in pytorch/attention.py and d_p_a/utils.py Signed-off-by: Kshitij Janardan Lakhani * Create d_p_a/inference.py and move InferenceParams from pyt/attention.py to it Modify tests and other files to import InferenceParams correctly Signed-off-by: Kshitij Janardan Lakhani Modify docs api for InferenceParams Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create d_p_a/rope.py and move RoPE methods from pytorch/attention.py to it Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code cleanup Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix qa testing induced bug Code clean up Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix incorrect pack_tensor arg type Code clean up Signed-off-by: Kshitij Janardan Lakhani * nit: Resolve lint errors Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove typedef FAUtils for FlashAttentionUtils Use attn_log instead of att_log Signed-off-by: Kshitij Janardan Lakhani Fix lint error Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: Fix the function name from get_cumul to the earlier get_cu Signed-off-by: Kshitij Janardan Lakhani * nit: Fix typos, explicit imports and remove extra comments Signed-off-by: Kshitij Janardan Lakhani --------- Signed-off-by: Kshitij Janardan Lakhani Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- docs/api/pytorch.rst | 2 +- docs/examples/attention/attention.ipynb | 2 +- docs/examples/te_llama/te_llama.py | 2 +- tests/pytorch/fused_attn/test_fused_attn.py | 59 +- .../fused_attn/test_fused_attn_with_cp.py | 8 +- tests/pytorch/test_fused_rope.py | 2 +- tests/pytorch/test_numerics.py | 2 +- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/attention.py | 2199 ++--------------- .../pytorch/dot_product_attention/__init__.py | 5 + .../dot_product_attention/inference.py | 53 + .../pytorch/dot_product_attention/rope.py | 225 ++ .../pytorch/dot_product_attention/utils.py | 1639 ++++++++++++ transformer_engine/pytorch/transformer.py | 4 +- 14 files changed, 2182 insertions(+), 2022 deletions(-) create mode 100644 transformer_engine/pytorch/dot_product_attention/__init__.py create mode 100644 transformer_engine/pytorch/dot_product_attention/inference.py create mode 100644 transformer_engine/pytorch/dot_product_attention/rope.py create mode 100644 transformer_engine/pytorch/dot_product_attention/utils.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 67a123d334..ca4bd91420 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -31,7 +31,7 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) :members: forward, set_context_parallel_group, set_tensor_parallel_group -.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length) +.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length) .. 
autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker() :members: reset, get_states, set_states, add, fork diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 16a3b05466..d20cd5c74e 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -458,7 +458,7 @@ " \n", "
    cuDNN 8.9.6+: sm90
    JAX, PaddlePaddle: `no_bias`, `post_scale_bias`JAX: `no_bias`, `post_scale_bias`ALiBi slopes: FP32cuDNN 9.0+: sm80+
    \n", "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
    \n", "Note\n", diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 5a40a62da7..3ddf7f411a 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -11,7 +11,7 @@ from torch import nn import transformer_engine as te -from transformer_engine.pytorch.attention import RotaryPositionEmbedding +from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding from transformer_engine.pytorch.fp8 import fp8_model_init import transformers diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index ff45d1e38f..d7585899a7 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -18,15 +18,15 @@ from transformer_engine.pytorch.attention import ( DotProductAttention, MultiheadAttention, - RotaryPositionEmbedding, + _attention_backends, +) +from transformer_engine.pytorch.dot_product_attention.utils import ( + FlashAttentionUtils, get_attention_backend, - _flash_attn_is_installed, - _flash_attn_2_3_plus, - _flash_attn_3_is_installed, check_set_window_size, AttentionParams, - _attention_backends, ) +from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding from transformer_engine.pytorch.constants import TE_DType import transformer_engine.pytorch.cpp_extensions as ext from transformer_engine.pytorch.cpp_extensions.fused_attn import ( @@ -191,9 +191,20 @@ def test(): fp8=fp8, fp8_meta=fp8_meta, ) - _, _, fused_attention_backend, _, available_backends = get_attention_backend( - attention_params - ) + ( + use_flash_attention, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) = get_attention_backend(attention_params) + # Set attention.py _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + 
_attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False return available_backends, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} @@ -269,12 +280,12 @@ def test_dot_product_attention( # mannually pads and unpads the input and output of FlashAttention for testing purposes if ( pad_between_seqs - and _flash_attn_is_installed + and FlashAttentionUtils.is_installed and not ( config.max_seqlen_q != config.max_seqlen_kv and config.attn_mask_type in ["causal", "padding_causal"] ) - and (config.window_size[0] == -1 or _flash_attn_2_3_plus) + and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) ): flash_attn_supported = True @@ -581,7 +592,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model): } -@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.") +@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) @@ -603,7 +614,7 @@ def test_dpa_sliding_window(dtype, model_configs, model): } -@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.") +@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes]) @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys()) @@ -1445,7 +1456,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ): pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if 
_flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1471,7 +1486,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1656,7 +1675,11 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1685,7 +1708,11 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 96321043bc..303c39e6c0 
100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -7,15 +7,11 @@ import pytest import torch -from transformer_engine.pytorch.attention import ( - _flash_attn_2_plus, - _flash_attn_2_3_plus, -) from transformer_engine.pytorch.utils import ( get_device_compute_capability, get_cudnn_version, ) - +from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils from test_fused_attn import ModelConfig model_configs_flash_attn = { @@ -54,7 +50,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args -@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.") +@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 7ad2c93aa5..e236a29a9d 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -5,7 +5,7 @@ import pytest import torch from typing import Callable, Tuple, Union -from transformer_engine.pytorch.attention import ( +from transformer_engine.pytorch.dot_product_attention.rope import ( RotaryPositionEmbedding, apply_rotary_pos_emb, ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 5bec7f7c7f..914ced130a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -34,10 +34,10 @@ RMSNorm, TransformerLayer, LayerNorm, - InferenceParams, Fp8Padding, Fp8Unpadding, ) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm 
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 166e72506b..5f20dbff85 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -89,8 +89,8 @@ def _load_library(): from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub from transformer_engine.pytorch.attention import DotProductAttention -from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.transformer import TransformerLayer from transformer_engine.pytorch.permutation import ( moe_permute, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5e5d4098b6..5865ba2df1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -12,17 +12,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import logging -import functools -from dataclasses import dataclass, fields import numpy as np from packaging.version import Version as PkgVersion import torch -import torch.nn.functional as F import transformer_engine_torch as tex -import transformer_engine as te from transformer_engine.pytorch.utils import ( get_cudnn_version, nvtx_range_pop, @@ -31,18 +27,9 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( fused_attn_fwd, fused_attn_bwd, - QKVLayout, - AttnBiasType, - AttnMaskType, FusedAttnBackend, META_QKV, - META_DQKV, META_O, - META_DO, - META_S, - META_DP, - META_O_CP, - META_DQKV_CP, ) from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, @@ -87,47 +74,18 @@ restore_from_saved, ) +# Import attention utils +import 
transformer_engine.pytorch.dot_product_attention.utils as dpa_utils +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils +from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log +from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] -_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") -_stream_handler = logging.StreamHandler() -_stream_handler.setFormatter(_formatter) -fa_logger = logging.getLogger(__name__) -fa_logger.setLevel(_log_level) -if not fa_logger.hasHandlers(): - fa_logger.addHandler(_stream_handler) - - -@functools.lru_cache(maxsize=None) -def _get_supported_versions(version_min, version_max): - return ">= " + str(version_min) + ", " + "<= " + str(version_max) - - -_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) -_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) -_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - -# Detect flash-attn v2 in the environment -_flash_attn_is_installed = False -_flash_attn_version = PkgVersion("0") -_flash_attn_version_required = PkgVersion("2.1.1") -_flash_attn_version_required_blackwell = PkgVersion("2.7.3") -_flash_attn_max_version = PkgVersion("2.7.4.post1") -_flash_attn_2_plus = False -_flash_attn_2_1_plus = False -_flash_attn_2_3_plus = False -_flash_attn_2_4_plus = False -_flash_attn_2_4_1_plus = False 
-_flash_attn_2_5_7_plus = False -_flash_attn_2_6_0_plus = False -_flash_attn_2_7_0_plus = False +# Setup Attention Logging +attn_log.setup_logging() + +# Global vars for flash attn imports flash_attn_cuda_bwd = None flash_attn_func = None flash_attn_varlen_func = None @@ -135,23 +93,26 @@ def _get_supported_versions(version_min, version_max): _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None - try: - _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) + fa_utils.version = PkgVersion(get_pkg_version("flash-attn")) except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( + if ( + torch.cuda.is_available() + and get_device_compute_capability() >= (8, 0) + and dpa_utils._NVTE_FLASH_ATTN + ): + attn_log.fa_logger.debug( "flash-attn v2 is not installed. To use, please install it by" """ "pip3 install flash-attn".""", ) else: if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): - if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version: - _flash_attn_is_installed = True - elif _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: - _flash_attn_is_installed = True + if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version: + fa_utils.is_installed = True + elif fa_utils.version_required <= fa_utils.version <= fa_utils.max_version: + fa_utils.is_installed = True - if _flash_attn_is_installed: + if fa_utils.is_installed: from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd @@ -163,51 +124,40 @@ def _get_supported_versions(version_min, version_max): _flash_attn_varlen_backward as _flash_attn_varlen_bwd, ) - _flash_attn_2_plus = _flash_attn_version >= 
PkgVersion("2") - _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") - _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") - _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") - _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") - _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") - _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") - _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0") + # Setup Flash attention utils + fa_utils.set_flash_attention_version() elif ( - torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN + torch.cuda.is_available() + and get_device_compute_capability() >= (8, 0) + and dpa_utils._NVTE_FLASH_ATTN ): - fa_logger.warning( + attn_log.fa_logger.warning( "Supported flash-attn versions are %s. Found flash-attn %s.", - _get_supported_versions( + dpa_utils._get_supported_versions( ( - _flash_attn_version_required + fa_utils.version_required if get_device_compute_capability() < (10, 0) - else _flash_attn_version_required_blackwell + else fa_utils.version_required_blackwell ), - _flash_attn_max_version, + fa_utils.max_version, ), - _flash_attn_version, + fa_utils.version, ) # Detect flash-attn v3 in the environment # This section will be removed when FA3 is released as a regular FA package, # i.e. 
flashattn-hopper 3.0.0 as flash-attn 3.0.0 -_flash_attn_3_is_installed = False -_flash_attn_3_version = PkgVersion("0") -_flash_attn_3_0_0_beta = False -_use_flash_attn_3 = False -# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved -# https://github.com/Dao-AILab/flash-attention/issues/1452 -_flash_attn_3_installation_steps = """\ -(1) pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" -(2) python_path=`python3 -c "import site; print(site.getsitepackages()[0])"` -(3) mkdir -p $python_path/flashattn_hopper -(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" try: - _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) + fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper")) except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( + if ( + torch.cuda.is_available() + and get_device_compute_capability() >= (9, 0) + and dpa_utils._NVTE_FLASH_ATTN + ): + attn_log.fa_logger.debug( "flash-attn v3 is not installed. 
To use, please install it by \n%s", - _flash_attn_3_installation_steps, + fa_utils.v3_installation_steps, ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 @@ -223,10 +173,9 @@ def _get_supported_versions(version_min, version_max): _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, ) - _flash_attn_3_is_installed = True - _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") - _use_flash_attn_3 = True + fa_utils.set_flash_attention_3_params() +# Global vars for available attention backends and ALiBi cache _attention_backends = { "attention_params": None, "use_flash_attention": None, @@ -236,107 +185,6 @@ def _get_supported_versions(version_min, version_max): "backend_selection_requires_update": False, } - -@dataclass(eq=True) -class AttentionParams: - """ - Attention parameters used to determine which backend to be used. - - Parameters - ---------- - qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor` - Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}. - qkv_dtype: torch.dtype, default = `torch.bfloat16` - Data type of query/key/value tensors. - qkv_layout: str, default = "sbh3d" - Query/key/value tensor memory layout. - batch_size: int, default = 1 - Batch size. - num_heads: int, default = 16 - Number of attention heads in the query tensor. - num_gqa_groups: int, default = 16 - Number of attention heads in key and value tensors. - max_seqlen_q: int, default = 128 - Maximum sequence length of the query tensor. - max_seqlen_kv: int, default = 128 - Maximum sequence length of the key and value tensors. - head_dim_qk: int, default = 64 - The size of each attention head in query and key tensors. - head_dim_v: int, default = 64 - The size of each attention head in the value tensor. 
- attn_mask_type: str, default = `no_mask` - Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, - `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} - window_size: Tuple[int, int], default = None - Sliding window attention size. - alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` - Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. - core_attention_bias_type: str, default = `no_bias` - Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}. - core_attention_bias_shape: str, default = `1hss` - Attention bias shape, {`1hss`, `b1ss`, `bhss`}. - core_attention_bias_requires_grad: bool, default = `True` - Whether attention bias requires gradient. - pad_between_seqs: bool, default = `False` - Whether there is padding between sequences in a batch. - This only applies to `qkv_format=thd`. - attention_dropout: float, default = 0.0 - Attention dropout. - context_parallel: bool, default = `False` - Whether context parallelism is used or not. - deterministic: bool, default = `False` - Whether to run `DotProductAttention` with determinism or not. - is_training: bool, default = `True` - Whether in training mode (`True`) or inference mode (`False`) - fp8: bool, default = `False` - Whether `DotProductAttention` is in an `fp8_autocast` region. - fp8_meta: Optional[Dict[str Any]], default = `None` - The FP8 metadata tensor of `DotProductAttention`. 
- """ - - qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor - qkv_dtype: torch.dtype = torch.bfloat16 - qkv_layout: str = "sbh3d" - batch_size: int = 1 - num_heads: int = 16 - num_gqa_groups: int = 16 - max_seqlen_q: int = 128 - max_seqlen_kv: int = 128 - head_dim_qk: int = 64 - head_dim_v: int = 64 - attn_mask_type: str = "no_mask" - window_size: Union[Tuple[int, int], None] = None - alibi_slopes_shape: Union[torch.Size, List, None] = None - core_attention_bias_type: str = "no_bias" - core_attention_bias_shape: str = "1hss" - core_attention_bias_requires_grad: bool = True - pad_between_seqs: bool = False - attention_dropout: float = 0.0 - context_parallel: bool = False - deterministic: bool = False - is_training: bool = True - fp8: bool = False - fp8_meta: Union[Dict[str, Any], None] = None - - def __eq__(self, other): - """ - Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared, - since all other entries of fp8_meta are unused in get_attention_backend. - """ - if not isinstance(other, self.__class__): - return NotImplemented - for field in fields(self): - fname = field.name - sf = getattr(self, fname) - of = getattr(other, fname) - if fname != "fp8_meta": - if sf != of: - return False - elif sf.get("recipe", None) != of.get("recipe", None): - return False - return True - - _alibi_cache = { "_num_heads": None, "_alibi_slopes": None, @@ -348,8 +196,7 @@ def __eq__(self, other): "_alibi_bias_require_update": False, } - -__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] +__all__ = ["DotProductAttention", "MultiheadAttention"] def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor: @@ -357,1196 +204,6 @@ def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor: return tensor.contiguous() if tensor.stride(-1) != 1 else tensor -def get_attention_backend( - attention_params: AttentionParams = None, -): - """ - Select the appropriate attention backend/sub-backend based on user input and runtime environment. 
- - Parameters - ---------- - See `AttentionParams`. - - Returns - ---------- - use_flash_attention: bool - Whether the `FlashAttention` backend has been selected. - use_fused_attention: bool - Whether the `FusedAttention` backend has been selected. - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`. - use_unfused_attention: bool - Whether the `UnfusedDotProductAttention` backend has been selected. - available_backends: List[bool] - All available backends that could support the provided input. A list of Booleans - in the form of [use_flash_attention, use_fused_attention, use_unfused_attention]. - """ - qkv_type = attention_params.qkv_type - qkv_dtype = attention_params.qkv_dtype - qkv_layout = attention_params.qkv_layout - batch_size = attention_params.batch_size - num_heads = attention_params.num_heads - num_gqa_groups = attention_params.num_gqa_groups - max_seqlen_q = attention_params.max_seqlen_q - max_seqlen_kv = attention_params.max_seqlen_kv - head_dim_qk = attention_params.head_dim_qk - head_dim_v = attention_params.head_dim_v - attn_mask_type = attention_params.attn_mask_type - window_size = attention_params.window_size - alibi_slopes_shape = attention_params.alibi_slopes_shape - core_attention_bias_type = attention_params.core_attention_bias_type - core_attention_bias_shape = attention_params.core_attention_bias_shape - core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad - pad_between_seqs = attention_params.pad_between_seqs - attention_dropout = attention_params.attention_dropout - context_parallel = attention_params.context_parallel - deterministic = attention_params.deterministic - is_training = attention_params.is_training - fp8 = attention_params.fp8 - fp8_meta = attention_params.fp8_meta - - # Run config - logger = logging.getLogger("DotProductAttention") - logger.setLevel(_log_level) - if not logger.hasHandlers(): - 
logger.addHandler(_stream_handler) - device_compute_capability = get_device_compute_capability() - cudnn_version = get_cudnn_version() - run_config = { - "transformer_engine_version": te.__version__, - "compute_capability": "sm" - + str(10 * device_compute_capability[0] + device_compute_capability[1]), - "flash_attn_version": ( - str(_flash_attn_version) if _flash_attn_is_installed else "not installed" - ), - "flash_attn_3_version": ( - str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed" - ), - "cudnn_version": ".".join([str(i) for i in cudnn_version]), - } - attention_params_dict = { - field.name: getattr(attention_params, field.name) for field in fields(attention_params) - } - run_config.update(attention_params_dict) - if fp8: - run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - logger.debug("Running with config=%s", run_config) - - # The following sections check if `FlashAttention` supports the provided attention params, - # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is - # necessary for performance/functionality, a warning will be issued to prompt users to - # install an appropriate FA version. 
- global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 - - # Filter: Environment variables - use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) - use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) - use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if not use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") - if not use_fused_attention: - logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") - if not use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") - - # Filter: Compute capability - if device_compute_capability < (8, 0): - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as it requires compute capability sm80+") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention as it requires compute capability sm80+") - use_fused_attention = False - if device_compute_capability < (9, 0): - if use_flash_attention and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") - _use_flash_attn_3 = False - - # Filter: Data type - if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ - torch.Tensor, - Float8Tensor, - ]: - if use_flash_attention and _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_flash_attention = False - if use_fused_attention: - logger.debug( - "Disabling FusedAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. 
" - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_fused_attention = False - - # Filter: Execution type - if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and not _use_flash_attn_3: - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") - use_flash_attention = False - if use_flash_attention and _use_flash_attn_3 and is_training: - logger.debug( - "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" - ) - use_flash_attention = False - if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") - use_unfused_attention = False - - # Filter: Head dimension - if use_flash_attention and head_dim_qk != head_dim_v: - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False - if use_flash_attention and ( - head_dim_qk > 256 - or head_dim_qk % 8 != 0 - or ( - head_dim_qk > 192 - and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) - ) - ): - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " - "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " - "head_dim_qk <= 256 (>192 requires sm80/90/100+). 
" - "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", - head_dim_qk, - head_dim_v, - ".".join([str(i) for i in device_compute_capability]), - ) - use_flash_attention = False - qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") - if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": - logger.debug( - "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", - qkv_layout, - ) - use_fused_attention = False - - # Filter: QKV layout - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") - use_unfused_attention = False - if use_flash_attention and pad_between_seqs: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " - "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" - ) - use_flash_attention = False - - # Filter: Dropout - if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for dropout") - _use_flash_attn_3 = False - - # Filter: Context parallelism - # qkv_format | attn_mask_type | attn_bias_type | supported backends - # ---------------------------------------------------------------------------------------------------- - # bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention - # | no_mask, causal | | - # | cross-attention: | | - # | no_mask | | - # thd | self-attention: | no_bias | FlashAttention, FusedAttention - # | padding, padding_causal | | if no padding between sequences, - # | cross-attention: | | FusedAttention - # | padding | | if there is padding between sequences - # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v. 
- if context_parallel and use_unfused_attention: - logger.debug( - "Disabling UnfusedDotProductAttention as it does not support context parallelism" - ) - use_unfused_attention = False - if context_parallel and use_flash_attention: - if fp8 and fp8_meta["recipe"].fp8_dpa: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with FP8" - ) - use_flash_attention = False - if "bottom_right" in attn_mask_type: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " causal_bottom_right masking" - ) - use_flash_attention = False - elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " causal masking for cross-attention" - ) - use_flash_attention = False - elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with bias" - " type of %s", - core_attention_bias_type, - ) - use_flash_attention = False - elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " attention bias for THD format" - ) - use_flash_attention = False - - if context_parallel and use_fused_attention: - if "bottom_right" in attn_mask_type: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with" - " causal_bottom_right masking" - ) - use_fused_attention = False - elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with causal" - " masking for cross-attention" - ) - use_fused_attention = False - elif core_attention_bias_type not in 
["no_bias", "post_scale_bias"]: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with bias type" - " of %s", - core_attention_bias_type, - ) - use_fused_attention = False - elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with attention" - " bias for THD format" - ) - use_fused_attention = False - elif head_dim_qk != head_dim_v: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with MLA" - ) - use_fused_attention = False - - # Filter: Attention mask - # attn_mask_type | attention_mask | supported backends - # ---------------------------------------------------------------------------------------- - # no_mask | None | All - # padding | | All - # self-attention | One tensor in shape [b, 1, 1, sq] | - # cross-attention | Tuple of two tensors in shapes | - # | [b, 1, 1, sq] and [b, 1, 1, skv] | - # causal | None | - # self-attention | | All - # cross-attention | | FusedAttention, UnfusedDotProductAttention - # padding_causal | Same as "padding" | - # self-attention | | All - # cross-attention | | FusedAttention, UnfusedDotProductAttention - # causal_bottom_right | None | All - # padding_causal_bottom_right | Same as "padding" | All - # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention - # | [b, h, sq, skv] | - if attn_mask_type == "arbitrary": - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention for arbitrary mask") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention for arbitrary mask") - use_fused_attention = False - if ( - use_flash_attention - and _use_flash_attn_3 - and attn_mask_type in ["causal", "padding_causal"] - and max_seqlen_q != max_seqlen_kv - ): - logger.warning( - "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " - "causal mask since 
flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - _use_flash_attn_3 = False - if ( - use_flash_attention - and attn_mask_type in ["causal", "padding_causal"] - and max_seqlen_q != max_seqlen_kv - ): - if _flash_attn_2_1_plus: - logger.warning( - "Disabling FlashAttention as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False - if not _flash_attn_is_installed: - _flash_attn_max_version = PkgVersion("2.1") - if ( - use_flash_attention - and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"] - and max_seqlen_q != max_seqlen_kv - ): - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.1") - elif not _flash_attn_2_1_plus and not _use_flash_attn_3: - logger.warning( - "Disabling FlashAttention as it only supports top-left-diagonal " - "causal mask before flash-attn 2.1. 
See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False - if ( - use_flash_attention - and _use_flash_attn_3 - and fp8 - and fp8_meta["recipe"].fp8_dpa - and "padding" in attn_mask_type - ): - logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") - _use_flash_attn_3 = False - - # Filter: Sliding window attention - # backend | window_size | diagonal alignment - # --------------------------------------------------------------------------------- - # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; - # | | converts window_size to an 'arbitrary' mask - if window_size is None: - window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" - ) - use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0: - logger.debug( - "Disabling FusedAttention as it only supports sliding window attention " - "with (left, 0) and no dropout" - ) - use_fused_attention = False - elif max_seqlen_q > max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with s_q > s_kv for cross-attention" - ) - use_fused_attention = False - if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if _use_flash_attn_3: - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) - _use_flash_attn_3 = False - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.3") - elif not _flash_attn_2_3_plus: - logger.debug( - "Disabling FlashAttention as sliding window 
attention requires flash-attn 2.3+" - ) - use_flash_attention = False - - # Filter: Attention bias - # backend | bias types | ALiBi diagonal alignment - # --------------------------------------------------------------------------------- - # FlashAttention | no_bias, alibi/alibi_slopes | bottom right - # FusedAttention | no_bias, post_scale_bias | - # | alibi/alibi_slopes | top left, - # | | bottom_right (converts to a 'post_scale_bias' bias) - # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | - # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias - if use_flash_attention and core_attention_bias_type == "alibi": - if _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for ALiBi") - _use_flash_attn_3 = False - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.4") - elif not _flash_attn_2_4_plus: - logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") - use_flash_attention = False - - if use_flash_attention and ( - core_attention_bias_type not in ["no_bias", "alibi"] - or core_attention_bias_shape is not None - ): - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention for pre/post_scale_bias") - use_flash_attention = False - - fu_core_attention_bias_type = core_attention_bias_type - fu_core_attention_bias_shape = core_attention_bias_shape - fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad - if ( - use_fused_attention - and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) - ): - fu_core_attention_bias_type = "post_scale_bias" - fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: - fu_core_attention_bias_shape = "1hss" - elif ( - len(alibi_slopes_shape) == 2 - and alibi_slopes_shape[0] == batch_size - and alibi_slopes_shape[1] == num_heads - ): - 
fu_core_attention_bias_shape = "bhss" - - if ( - use_fused_attention - and fu_core_attention_bias_type == "post_scale_bias" - and fu_core_attention_bias_shape != "1hss" - ): - if fu_core_attention_bias_requires_grad: - # remove this line when cuDNN adds bwd support for - # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] - logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") - use_fused_attention = False - else: - # max512 backend will only support [1, h, s, s] - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - - # Filter: cuDNN support - fused_attention_backend = None - if use_fused_attention: - q_type = TE_DType[qkv_dtype] - kv_type = q_type - if fp8 and fp8_meta["recipe"].fp8_dpa: - q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - kv_type = q_type - fused_attention_backend = tex.get_fused_attn_backend( - q_type, - kv_type, - QKVLayout[qkv_layout], - AttnBiasType[fu_core_attention_bias_type], - AttnMaskType[attn_mask_type], - attention_dropout, - num_heads, - num_gqa_groups, - max_seqlen_q, - max_seqlen_kv, - head_dim_qk, - head_dim_v, - window_size[0], - window_size[1], - ) - if fused_attention_backend == FusedAttnBackend["No_Backend"]: - logger.debug("Disabling FusedAttention as no backend supports the provided input") - use_fused_attention = False - fused_attention_backend = None - if ( - use_fused_attention - and window_size is not None - and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "slidng window attention", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None - if ( - use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - and fu_core_attention_bias_type == "post_scale_bias" - and fu_core_attention_bias_shape != "1hss" - ): - logger.debug( - "Disabling FusedAttention as cuDNN sub-backend 0 only 
supports post_scale_bias in" - " [1, H, S, S] shape" - ) - use_fused_attention = False - fused_attention_backend = None - - # Filter: Determinism - # backend | deterministic - # --------------------------------------------- - # FlashAttention | - # flash-attn >=2.0, <2.4.1 | no - # flash-attn >=2.4.1 | yes - # FusedAttention | - # sub-backend 0 | yes - # sub-backend 1 | workspace optimization path and sm90+: yes; - # | otherwise: no - # sub-backend 2 | no - # UnfusedDotProductAttention | yes - if use_flash_attention and deterministic: - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.4.1") - elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3: - logger.warning( - "Disabling FlashAttention as version <2.4.1 does not support deterministic " - "execution. To use FlashAttention with deterministic behavior, " - "please install flash-attn >= 2.4.1." - ) - use_flash_attention = False - if use_fused_attention and deterministic: - if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons") - use_fused_attention = False - if ( - fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - and is_training - and ( - device_compute_capability < (9, 0) - or core_attention_bias_requires_grad - or cudnn_version < (8, 9, 5) - ) - ): - logger.debug("Disabling FusedAttention for determinism reasons") - use_fused_attention = False - - # All available backends - available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] - - # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. - # When `FusedAttention` does not support the provided attention params, and `FlashAttention` - # does, we recommend users to install flash-attn if not installed already. 
- if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed: - logger.warning( - "flash-attn may provide important feature support or performance improvement." - " Please install flash-attn %s.", - _get_supported_versions( - _flash_attn_version_required, - _flash_attn_max_version, - ), - ) - if use_flash_attention and not _flash_attn_is_installed: - use_flash_attention = False - available_backends[0] = False - - logger.debug( - "Available backends = {FlashAttention=%s, FusedAttention=%s%s," - " UnfusedDotProductAttention=%s}", - bool(available_backends[0]), - bool(available_backends[1]), - ( - f" (sub-backend {int(fused_attention_backend)})" - if fused_attention_backend is not None - else "" - ), - bool(available_backends[2]), - ) - - # Select FusedAttention for performance - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - ): - if device_compute_capability >= (9, 0): - logger.debug( - "Disabling FlashAttention to give FusedAttention preference on Hopper+ " - "for performance reasons" - ) - use_flash_attention = False - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["FP8"] - and _use_flash_attn_3 - ): - logger.debug( - "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " - "in FP8 execution" - ) - use_flash_attention = False - - # Selected backend - if use_flash_attention: - use_fused_attention = False - use_unfused_attention = False - elif use_fused_attention: - use_unfused_attention = False - selected_backend = "NoBackend" - if use_flash_attention: - selected_backend = "FlashAttention" - elif use_fused_attention: - selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" - elif use_unfused_attention: - selected_backend = "UnfusedDotProductAttention" - logger.debug("Selected backend = %s", selected_backend) - - global _attention_backends - 
_attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - - return ( - use_flash_attention, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - available_backends, - ) - - -class InferenceParams: # pylint: disable=too-few-public-methods - """ - Inference parameters that are passed to the main model in order - to efficiently calculate and store the context during inference. - - Parameters - ---------- - max_batch_size : int - maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. - """ - - def __init__(self, max_batch_size, max_sequence_length): - self.max_sequence_length = max_sequence_length - self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 - self.key_value_memory_dict = {} - - def swap_key_value_dict(self, batch_indices): - """ - Reorders the KV cache using the specified batch indices. - - Parameters - ---------- - batch_indices : List[int] - Sequence of indices to reorder along the batch dimensions of - the KV cache. Must have a length equal to the batch size. 
- """ - if len(self.key_value_memory_dict) == 0: - raise ValueError("should not swap when dict in empty") - - for layer_number, inference_memory in self.key_value_memory_dict.items(): - inference_key_memory, inference_value_memory = inference_memory - assert ( - len(batch_indices) == inference_key_memory.shape[1] - ) # make sure batch size is the same - new_inference_key_memory = inference_key_memory[:, batch_indices] - new_inference_value_memory = inference_value_memory[:, batch_indices] - self.key_value_memory_dict[layer_number] = ( - new_inference_key_memory, - new_inference_value_memory, - ) - - -@torch.no_grad() -def get_full_mask( - max_seqlen_q: int, - max_seqlen_kv: int, - attn_mask_type: str = "no_mask", - attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, - window_size: Tuple[int, int] = None, - attention_type: str = "self", - bottom_right_alignment: bool = True, -) -> torch.Tensor: - """ - Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`, - `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends - on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.:: - - attn_mask_type output shape diagonal alignment - -------------------------------------------------------------------------------------------- - no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment - causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left - causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right - padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment - padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left - padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right - arbitrary same as attention_mask follow bottom_right_alignment - - .. 
note:: - - For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right - diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, - i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4, - max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = ( - [[False, False, True, True], [False, False, False, False]], - [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4] - shape and is,:: - - [[[False, False, False, True], - [False, False, False, True], - [ True, True, True, True], - [ True, True, True, True]], - [[False, True, True, True], - [False, True, True, True], - [False, True, True, True], - [False, True, True, True]]] - - Parameters - ---------- - max_seqlen_q: int - Maximum sequence length for queries. - max_seqlen_kv: int - Maximum sequence length for keys and values. - attn_mask_type: str, default = `no_mask` - Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", - "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} - attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - default = `None` - Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention - for the requirements of `attention_mask` for different `attn_mask_type`s. - window_size: Tuple[int, int], default = `None` - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. Both `causal` and `causal_bottom_right` masks - map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on - `attn_mask_type`. 
- attention_type: str, default = "self" - Attention type, {"self", "cross"} - bottom_right_alignment: bool, default = `True` - Whether to align the diagonal of the sliding window attention to the bottom right (`True`) - or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly - specifies "causal" or "causal_bottom_right". - - Returns - ---------- - attn_mask_type: str - For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type` - attention_mask: torch.Tensor - The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` - actual_seqlens_q: torch.Tensor - For padding masks, the actual sequence lengths for queries, in shape [batch_size]. - For other masks, `None`. - actual_seqlens_kv: Optional[torch.Tensor], default = `None` - For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. - For other masks, `None`. - """ - # perform basic checks - change_type = window_size is not None and ( - window_size[0] != -1 or window_size[1] not in [-1, 0] - ) - if window_size is None: - window_size = (-1, -1) - if "causal" in attn_mask_type: - window_size = (window_size[0], 0) - window_size = ( - max_seqlen_kv if window_size[0] == -1 else window_size[0], - max_seqlen_q if window_size[1] == -1 else window_size[1], - ) - - # apply padding mask - actual_seqlens_q = None - actual_seqlens_kv = None - if "padding" in attn_mask_type: - if attention_type == "self": - attention_mask = torch.logical_or( - attention_mask.squeeze(1).unsqueeze(3), attention_mask - ) - else: - attention_mask = torch.logical_or( - attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] - ) - m = attention_mask.logical_not() - actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) - actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) - - # apply SWA mask - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, 
dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) - swa_left = None - swa_right = None - if attn_mask_type == "causal_bottom_right" or ( - attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment - ): - swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] - swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] - elif attn_mask_type in ["causal", "padding_causal"] or ( - attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment - ): - swa_left = mask - window_size[0] - swa_right = mask + window_size[1] - elif attn_mask_type == "padding_causal_bottom_right" or ( - attn_mask_type == "padding" and bottom_right_alignment - ): - batch_size = attention_mask.shape[0] - swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( - actual_seqlens_kv - actual_seqlens_q - window_size[0] - ).view(batch_size, 1, 1, 1) - swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( - actual_seqlens_kv - actual_seqlens_q + window_size[1] - ).view(batch_size, 1, 1, 1) - swa_mask = torch.logical_not( - torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) - ) - if attention_mask is not None: - attention_mask = torch.logical_or(swa_mask, attention_mask) - else: - attention_mask = swa_mask - - # change mask type - if change_type: - attn_mask_type = "arbitrary" - - return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv - - -@torch.no_grad() -def get_alibi( - num_heads: int, - max_seqlen_q: int, - max_seqlen_kv: int, - actual_seqlens_q: Optional[torch.Tensor] = None, - actual_seqlens_kv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - bias_dtype: Optional[torch.dtype] = None, - bottom_right_alignment: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Parameters - ---------- - num_heads: int - Number of heads. - max_seqlen_q: int - Maximum sequence length for queries. 
- max_seqlen_kv: int - Maximum sequence length for keys and values. - actual_seqlens_q: Optional[torch.Tensor], default = `None` - Actual sequence lengths for queries, in shape [batch_size]. - actual_seqlens_kv: Optional[torch.Tensor], default = `None` - Actual sequence lengths for keys and values, in shape [batch_size]. - alibi_slopes: Optional[torch.Tensor], default = `None` - Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. - bias_dtype: Optional[torch.dtype], default = `None` - Dtype of the generated ALiBi bias. If None, use torch.float32. - bottom_right_alignment: bool, default = `True` - Whether to align the diagonal of the ALiBi bias to the bottom right corner of - the matrix (`True`) or top left (`False`). - - Returns - ---------- - alibi_slopes: torch.Tensor - ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. - alibi_bias: torch.Tensor - ALiBi bias in FP32 or `bias_dtype`. Its shape is - (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, - and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or - (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in - [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and - `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. 
- """ - global _alibi_cache - if _alibi_cache["_alibi_slopes_require_update"]: - if alibi_slopes is not None: - _alibi_cache["_alibi_slopes"] = alibi_slopes - else: - n = 2 ** math.floor(math.log2(num_heads)) - m_0 = 2.0 ** (-8.0 / n) - m = torch.pow(m_0, torch.arange(1, 1 + n)) - - if n < num_heads: - m_hat_0 = 2.0 ** (-4.0 / n) - m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2)) - m = torch.cat([m, m_hat]) - - _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda") - _alibi_cache["_num_heads"] = num_heads - _alibi_cache["_alibi_slopes_require_update"] = False - - if _alibi_cache["_alibi_bias_require_update"]: - assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!" - if _alibi_cache["_alibi_slopes"].dim() == 1: - slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) - elif _alibi_cache["_alibi_slopes"].dim() == 2: - slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - else: - raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") - - bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - if actual_seqlens_q is None and actual_seqlens_kv is None: - if bottom_right_alignment: - bias = bias + max_seqlen_kv - max_seqlen_q - elif actual_seqlens_q is not None and actual_seqlens_kv is not None: - batch_size = actual_seqlens_q.shape[0] - bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - if bottom_right_alignment: - bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - else: - assert ( - False - ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" 
- bias = bias.abs().mul(-1) - bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) - _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv - _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment - bias_dtype = torch.float32 if bias_dtype is None else bias_dtype - _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") - _alibi_cache["_alibi_bias_require_update"] = False - - return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"] - - -def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: - """ - Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32 - tensor of shape [batch_size + 1] containing the cumulative sequence lengths of - the samples in a batch. - """ - mask = mask.squeeze(1).squeeze(1) - reduced_mask = mask.logical_not().sum(dim=1) - cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") - cu_seqlens = torch.cat((zero, cu_seqlens)) - - return cu_seqlens - - -def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32 - tensor of shape [batch_size + 1] containing the cumulative sequence lengths of - the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1] - containing the indices for the valid tokens. 
- """ - mask = mask.squeeze(1).squeeze(1) - bs, seqlen = mask.shape - - reduced_mask = mask.logical_not().sum(dim=1) - cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") - cu_seqlens = torch.cat((zero, cu_seqlens)) - - mask = mask.reshape(-1) - indices = mask.logical_not().nonzero() - indices = indices.unsqueeze(-1) - - num_nonzeros = indices.shape[0] - pad_amount = bs * seqlen - num_nonzeros - indices = F.pad( - input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen) - ) - - return cu_seqlens, indices - - -def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: - """ - Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int32 - tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for - the valid tokens in a batch. - """ - bs = len(cu_seqlens) - 1 - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)] - indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda") - - num_nonzeros = indices.shape[0] - pad_amount = bs * max_seqlen - num_nonzeros - indices = F.pad( - input=indices, - pad=(0, 0, 0, 0, 0, pad_amount), - mode="constant", - value=float(bs * max_seqlen), - ) - - return indices - - -_cu_seqlens_cache = {} - - -def _get_full_cu_seqlens( - batch_size: int, - max_seqlen: int, - device: torch.device, -) -> torch.Tensor: - """Cumulative sequence lengths in full data batch - - All sequences in batch have the maximum sequence length. 
- - """ - global _cu_seqlens_cache - if (batch_size, max_seqlen) not in _cu_seqlens_cache: - _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( - 0, - (batch_size + 1) * max_seqlen, - step=max_seqlen, - dtype=torch.int32, - device=device, - ) - return _cu_seqlens_cache[(batch_size, max_seqlen)] - - -@jit_fuser -def pack_tensor( - indices: torch.Tensor, - tensor: torch.Tensor, -) -> torch.Tensor: - """ - Packs the given tensor using the `indices`. - """ - padding_indice = torch.zeros( - 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device - ) - indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) - if isinstance(tensor, Float8Tensor): - tensor_data = torch.cat((tensor._data, padding_indice), dim=0) - gathered_data = torch.gather(tensor_data, 0, indices) - - packed = Float8Tensor.make_like(tensor, data=gathered_data, shape=gathered_data.shape) - else: - tensor = torch.cat((tensor, padding_indice), dim=0) - - packed = torch.gather(tensor, 0, indices) - return packed - - -@jit_fuser -def pack_2_tensors( - indices: torch.Tensor, - t1: torch.Tensor, - t2: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Packs the given 2 tensors using the `indices`. - """ - t1_packed = pack_tensor(indices, t1) - t2_packed = pack_tensor(indices, t2) - return t1_packed, t2_packed - - -@jit_fuser -def pack_3_tensors( - indices: torch.Tensor, - t1: torch.Tensor, - t2: torch.Tensor, - t3: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Packs the given 3 tensors using the `indices`. - """ - t1_packed = pack_tensor(indices, t1) - t2_packed = pack_tensor(indices, t2) - t3_packed = pack_tensor(indices, t3) - return t1_packed, t2_packed, t3_packed - - -@jit_fuser -def unpack_tensor( - indices: torch.Tensor, - dim0: int, - tensor: torch.Tensor, -) -> torch.Tensor: - """ - Inverse of `pack_tensor`. 
- """ - indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) - unpacked = torch.zeros( - dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device - ) - if isinstance(tensor, Float8Tensor): - unpacked.scatter_(0, indices, tensor._data) - unpacked_data = unpacked[0:-1, :, :] - unpacked = Float8Tensor.make_like(tensor, data=unpacked_data, shape=unpacked_data.shape) - else: - unpacked.scatter_(0, indices, tensor) - unpacked = unpacked[0:-1, :, :] - return unpacked - - -@jit_fuser -def unpack_2_tensors( - indices: torch.Tensor, - dim0: int, - t1: torch.Tensor, - t2: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Inverse of `pack_2_tensors`. - """ - t1_unpacked = unpack_tensor(indices, dim0, t1) - t2_unpacked = unpack_tensor(indices, dim0, t2) - return t1_unpacked, t2_unpacked - - -@jit_fuser -def unpack_3_tensors( - indices: torch.Tensor, - dim0: int, - t1: torch.Tensor, - t2: torch.Tensor, - t3: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Inverse of `pack_3_tensors`. - """ - t1_unpacked = unpack_tensor(indices, dim0, t1) - t2_unpacked = unpack_tensor(indices, dim0, t2) - t3_unpacked = unpack_tensor(indices, dim0, t3) - return t1_unpacked, t2_unpacked, t3_unpacked - - -class PackTensors(torch.autograd.Function): - """ - Autograd function to pack tensors. - """ - - @staticmethod - def forward( - ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...] - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: - # pylint: disable=missing-function-docstring - assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." 
- ctx.save_for_backward(indices) - ctx.dim0 = tensors[0].shape[0] - if len(tensors) == 1: - return pack_tensor(indices, *tensors) - if len(tensors) == 2: - return pack_2_tensors(indices, *tensors) - return pack_3_tensors(indices, *tensors) - - @staticmethod - def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): - # pylint: disable=missing-function-docstring - (indices,) = ctx.saved_tensors - if len(grad_outputs) == 1: - return None, unpack_tensor(indices, ctx.dim0, *grad_outputs) - if len(grad_outputs) == 2: - return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs) - return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs) - - -class UnpackTensor(torch.autograd.Function): - """ - Autograd function to unpack a tensor. - """ - - @staticmethod - def forward( - ctx, - indices: torch.Tensor, - dim0: int, - tensor: torch.Tensor, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - ctx.save_for_backward(indices) - return unpack_tensor(indices, dim0, tensor) - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - (indices,) = ctx.saved_tensors - return None, None, pack_tensor(indices, grad_output) - - def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -1811,49 +468,6 @@ def flash_attn_a2a_communicate( return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs -def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): - """Get the list of quantizers used in attention from the quantizers list.""" - if not fp8: - num_of_nones = 8 if cp_specific_quantizers else 6 - return [None] * num_of_nones - QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True - QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) - S_quantizer = quantizers["scaling_fwd"][META_S] - 
S_quantizer.internal = True - S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer = quantizers["scaling_bwd"][META_DO] - dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True - dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) - dP_quantizer.interal = True - dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] - dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_CP_quantizer.internal = True - O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] - O_CP_quantizer.set_usage(rowwise=True, columnwise=False) - - if cp_specific_quantizers: - return ( - QKV_quantizer, - O_quantizer, - O_CP_quantizer, - S_quantizer, - dQKV_quantizer, - dQKV_CP_quantizer, - dO_quantizer, - dP_quantizer, - ) - - return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer - - _cu_seqlens_info_with_cp_cache = {} @@ -1988,7 +602,7 @@ def forward( dQKV_CP_quantizer, dO_quantizer, dP_quantizer, - ) = get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) if fp8: if use_fused_attention: @@ -2071,12 +685,12 @@ def forward( if use_fused_attention: softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) else: - softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3 + softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or fa_utils.use_v3 flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: + if fa_utils.use_v3: if qkv_format == "thd": flash_attn_fwd = _flash_attn_varlen_fwd_v3 else: @@ -2089,16 +703,16 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p 
fa_forward_kwargs["return_softmax"] = False - if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3: + if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus) or fa_utils.use_v3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = 0 if causal else -1 - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus and qkv_format == "thd": + if fa_utils.v2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_0_plus: + if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 # Flash Attn inputs @@ -2266,15 +880,15 @@ def forward( causal=True, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[3] elif i <= rank: if pad_between_seqs: @@ -2380,11 +994,11 @@ def forward( max_seqlen_q, max_seqlen_kv // 2, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + if fa_utils.use_v3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( @@ -2403,15 +1017,15 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not fa_utils.use_v3: 
rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[3] else: if pad_between_seqs: @@ -2526,11 +1140,11 @@ def forward( max_seqlen_q // 2, max_seqlen_kv, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + if fa_utils.use_v3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus ): fa_forward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( @@ -2549,15 +1163,15 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[3] else: if pad_between_seqs: @@ -2664,15 +1278,15 @@ def forward( causal=False, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[3] if i > 0: @@ -3036,7 +1650,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: + if fa_utils.use_v3: if ctx.qkv_format == "thd": flash_attn_bwd = _flash_attn_varlen_bwd_v3 else: @@ -3048,11 +1662,11 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd 
fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_0_plus: + if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 for i in range(cp_size): @@ -3186,14 +1800,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus - ): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, 0) - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 - if not _use_flash_attn_3: + if not fa_utils.use_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3303,14 +1915,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus - ): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - if _flash_attn_2_7_0_plus: + if fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not _use_flash_attn_3: + if not fa_utils.use_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3422,14 +2032,12 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus - ): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not 
_use_flash_attn_3: + if not fa_utils.use_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, @@ -3518,12 +2126,12 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 - if not _use_flash_attn_3: + if not fa_utils.use_v3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout, @@ -3851,13 +2459,13 @@ def forward( assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( - use_fused_attention or _flash_attn_2_3_plus + use_fused_attention or fa_utils.v2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: + if fa_utils.use_v3: if qkv_format == "thd": flash_attn_fwd = _flash_attn_varlen_fwd_v3 else: @@ -3869,11 +2477,11 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus and qkv_format == "thd": + if fa_utils.v2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_0_plus: + if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
@@ -3947,7 +2555,7 @@ def forward( ) max_seqlen_kv_ = seq_end_idx - seq_start_idx if use_fused_attention or qkv_format == "thd": - cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( k.shape[1], max_seqlen_kv_, k.device ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] @@ -3984,11 +2592,9 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] - if _use_flash_attn_3 or ( - _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus - ): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_forward_kwargs["window_size"] = window_size_per_step[i] - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( @@ -3999,15 +2605,15 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] - if not _use_flash_attn_3: + if not fa_utils.use_v3: rng_states[i] = fa_outputs[3] if i > 0: @@ -4107,7 +2713,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: + if fa_utils.use_v3: if ctx.qkv_format == "thd": flash_attn_bwd = _flash_attn_varlen_bwd_v3 else: @@ -4119,11 +2725,11 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if _flash_attn_2_6_0_plus: + if 
fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 for i in range(len(local_seq_chunk_ids) + 1): @@ -4180,11 +2786,11 @@ def backward(ctx, dout): ctx.max_seqlen_q, max_seqlen_kv, ] - if not _use_flash_attn_3: + if not fa_utils.use_v3: fa_backward_kwargs["rng_state"] = rng_states[i] - if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] - if _flash_attn_2_7_0_plus: + if fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( @@ -4318,13 +2924,13 @@ def forward( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention - or _flash_attn_2_3_plus + or fa_utils.v2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: + if fa_utils.use_v3: if qkv_format == "thd": flash_attn_fwd = _flash_attn_varlen_fwd_v3 else: @@ -4337,16 +2943,16 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_forward_kwargs["window_size"] = window_size - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size[0] fa_forward_kwargs["window_size_right"] = window_size[1] - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus and qkv_format == "thd": + if fa_utils.v2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None - if _flash_attn_2_6_0_plus: + if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 
0.0 assert ( @@ -4368,7 +2974,7 @@ def forward( is_output_fp8 = False QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) ) if fp8: if use_fused_attention: @@ -4458,12 +3064,12 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not _flash_attn_2_7_0_plus: + if not fa_utils.v2_7_0_plus: out, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + rng_state = fa_outputs[7] if not fa_utils.use_v3 else None else: out, softmax_lse = fa_outputs[0], fa_outputs[1] - rng_state = fa_outputs[3] if not _use_flash_attn_3 else None + rng_state = fa_outputs[3] if not fa_utils.use_v3 else None aux_ctx_tensors = [softmax_lse, rng_state] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) @@ -4634,7 +3240,7 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: + if fa_utils.use_v3: if ctx.qkv_format == "thd": flash_attn_bwd = _flash_attn_varlen_bwd_v3 else: @@ -4647,16 +3253,16 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): + if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = ctx.window_size - elif _flash_attn_2_7_0_plus: + elif fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = ctx.window_size[0] fa_backward_kwargs["window_size_right"] = ctx.window_size[1] - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic - if 
_flash_attn_2_6_0_plus: + if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 if ctx.use_fused_attention: @@ -4723,7 +3329,7 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if not _use_flash_attn_3: + if not fa_utils.use_v3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( dout, @@ -4904,221 +3510,6 @@ def attn_forward_func_with_cp( return out -class RotaryPositionEmbedding(torch.nn.Module): - """ - Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. - """ - - def __init__( - self, - dim: int, - rotary_percent: float = 1.0, - seq_len_interpolation_factor: Optional[int] = None, - pretrained_max_position_embeddings: Optional[int] = None, - rotary_base: float = 10000.0, - ): - """ - Parameters - ---------- - dim: int - rotary embedding dimension - rotary_percent: float - Percent of rotary dimension to use for rotary position embeddings. - seq_len_interpolation_factor: int - if not None, discrete positions will be interpolated by this factor via the trick in - https://arxiv.org/abs/2306.15595 - pretrained_max_position_embeddings: int - pre-trained max_position_embeddings before position interpolation - """ - super().__init__() - if rotary_percent < 1.0: - dim = int(dim * rotary_percent) - self.seq_len_interpolation_factor = seq_len_interpolation_factor - self.rotary_base = rotary_base - inv_freq = 1.0 / ( - self.rotary_base - ** ( - torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) - / dim - ) - ) - self.register_buffer("inv_freq", inv_freq) - self.pretrained_max_position_embeddings = pretrained_max_position_embeddings - - def forward(self, max_seq_len: int, offset: int = 0): - """ - Create rotary position embedding frequencies - - Parameters - ---------- - max_seq_len: int - sequence length of a sample - offset: int, default = 0 - fixed offset for freqencies - """ - seq = ( - torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - + offset - ) - 
- if ( - self.pretrained_max_position_embeddings is not None - and self.seq_len_interpolation_factor is not None - ): - if ( - max_seq_len - > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor - ): - # dynamic linear scaling (length > position we have learned) - seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) - else: - # fixed linear scaling - seq *= 1 / self.seq_len_interpolation_factor - - freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) - # first part even vector components, second part odd vector components, - # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) - # emb [seq_length, .., dim] - return emb.reshape(emb.size(0), 1, 1, emb.size(1)) - - -class FusedRoPEFunc(torch.autograd.Function): - """ - Function for FusedRoPE - - This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and - the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid - the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. 
- """ - - @staticmethod - def forward( - ctx, - t: torch.Tensor, - freqs: torch.Tensor, - tensor_format: str = "sbhd", - cu_seqlens: Union[torch.Tensor, None] = None, - cp_size: int = 1, - cp_rank: int = 0, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - if freqs.dtype != torch.float32: - freqs = freqs.float() - if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) - elif tensor_format == "bshd": - output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) - elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) - else: - raise ValueError(f"Unsupported tensor_format: {tensor_format}.") - ctx.save_for_backward(freqs, cu_seqlens) - ctx.tensor_format = tensor_format - ctx.cp_size = cp_size - ctx.cp_rank = cp_rank - - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - freqs, cu_seqlens = ctx.saved_tensors - if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) - elif ctx.tensor_format == "bshd": - grad_input = tex.fused_rope_backward( - grad_output.transpose(0, 1), freqs, True - ).transpose(0, 1) - elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward( - grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank - ) - else: - raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") - - return grad_input, None, None, None, None, None - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """ - change sign so the last dimension becomes [-odd, +even] - """ - x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, - tensor_format: str = "sbhd", - fused: bool = False, - cu_seqlens: Union[torch.Tensor, None] = None, - 
cp_size: int = 1, - cp_rank: int = 0, -) -> torch.Tensor: - """ - Apply rotary positional embedding tensor to the input tensor. - - Parameters - ---------- - t: torch.Tensor - Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which - rotary positional embedding will be applied. - freqs: torch.Tensor - Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', - with `s2 >= s` and `d2 <= d`. - fused: bool, default = False - Whether to use a fused applying RoPE implementation. - tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' - is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is - of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. - cu_seqlens: torch.Tensor, default = None. - Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and - dtype torch.int32. Only valid when `tensor_format` is 'thd'. - Should be `cu_seqlens_padded` when cp_size > 1. - cp_size: int, default = 1. - Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True. - cp_rank: int, default = 0. - Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. - """ - if fused: - assert ( - tensor_format != "thd" or cu_seqlens is not None - ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) - - assert tensor_format in ("sbhd", "bshd"), ( - "Only formats `sbhd` or `bshd` are supported for input tensor `t` " - f"when fused is False, got {tensor_format}." - ) - - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # Only apply the rotary embeddings up to the sequence length of the running - # input. - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
- freqs = freqs[:cur_seq_len] - if tensor_format == "bshd": - freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] - # cos/sin first then dtype conversion for better precision - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) - - rot_dim = freqs.shape[-1] - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) - - class _SplitAlongDim(torch.autograd.Function): """""" @@ -5320,13 +3711,15 @@ def forward( key_layer.shape[0], ) - attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( - max_seqlen_q, - max_seqlen_kv, - attn_mask_type=attn_mask_type, - attention_mask=attention_mask, - window_size=window_size, - attention_type=self.attention_type, + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( + dpa_utils.get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + ) ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] @@ -5392,7 +3785,8 @@ def forward( if core_attention_bias_type == "post_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" if core_attention_bias_type == "alibi": - _, core_attention_bias = get_alibi( + _, core_attention_bias = dpa_utils.get_alibi( + _alibi_cache, output_size[1], output_size[2], output_size[3], @@ -5501,202 +3895,6 @@ def backward( return dq, dk, dv -def get_qkv_layout( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qkv_format: str = "sbhd", -) -> str: - """Get qkv layout. - - Parameters - ---------- - q: torch.Tensor - Query tensor. - k: torch.Tensor - Key tensor. 
- v: torch.Tensor - Value tensor. - qkv_format: str, default = `sbhd` - Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for - the sequence length dimension, `b` batch size, `h` the number of attention heads, - `d` head size, and `t` the total number of tokens in a batch, i.e. - `t = sum(s_i) for i = 0...b-1`. - - Returns - ---------- - qkv_layout: str - Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five - memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk - of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means - `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v` - are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and - `v = kv[:,:,:,1,:]`. - Mapping: - `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} - `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} - `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} - q: torch.Tensor - Query tensor. It may be different from input `q` as we try to fit tensors to - a supported layout. - k: torch.Tensor - Key tensor. It may be different from input `k` as we try to fit tensors to - a supported layout. - v: torch.Tensor - Value tensor. It may be different from input `v` as we try to fit tensors to - a supported layout. - """ - - check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) - assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" 
- - def run_iteratively(q, k, v): - # check data pointers - data_ptr = q.untyped_storage().data_ptr() - check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) - check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) - data_ptr = k.untyped_storage().data_ptr() - check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) - - # check tensor shapes - shape = q.shape - check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) - shape = k.shape - check_shapes_kv = shape[:-1] == v.shape[:-1] - - # check tensor strides - stride = q.stride() - check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) - check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( - sv / v.shape[-1] for sv in v.stride()[:-1] - ) - - # check tensor offsets for h3d and 3hd layouts - prod_h_d = q.shape[-1] * q.shape[-2] - check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v])) - check_h3d_offsets = all( - x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v]) - ) - - # check tensor offsets for hd_h2d and hd_2hd layouts - prod_all_dims = [np.prod(x.shape) for x in [q, k]] - offset = prod_all_dims[0] if check_ptrs_qkv else 0 - prod_h_d = k.shape[-1] * k.shape[-2] - check_2hd_offsets = all( - x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v]) - ) - check_h2d_offsets = all( - x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v]) - ) - - # check tensor offsets for hd_hd_hd layouts - check_hd_offsets_qkv = ( - all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v])) - if check_ptrs_qkv - else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v])) - ) - check_hd_offsets_qk = ( - all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k])) - if not check_ptrs_qkv and check_ptrs_qk - else all(x.storage_offset() == 0 for i, x in enumerate([q, k])) - ) - 
check_hd_offsets_kv = ( - all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v])) - if not check_ptrs_qkv and check_ptrs_kv - else all(x.storage_offset() == 0 for i, x in enumerate([k, v])) - ) - - if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets: - # sb3hd, bs3hd, t3hd - # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv - qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:] - elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets: - # sbh3d, bsh3d, th3d - # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv - qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:] - elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets: - # sbhd_sb2hd, bshd_bs2hd, thd_t2hd - # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv - # q and kv may be disjoint or consecutive in memory, and when consecutive, they may - # have the same data pointer, i.e. check_ptrs_qkv=True - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] - elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets: - # sbhd_sbh2d, bshd_bsh2d, thd_th2d - # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv - # q and kv may be disjoint or consecutive in memory, and when consecutive, they may - # have the same data pointer, i.e. check_ptrs_qkv=True - qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:] - elif ( - check_strides_kv - and check_shapes_kv - and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) - ): - # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd - # three chunks of memory, q, k and v, which may be disjoint or consecutive, and - # when consecutive, they may have the same data pointer, i.e. 
check_ptrs_qkv=True or - # check_ptrs_qk=True or check_ptrs_kv=True - qkv_layout = "_".join(list([qkv_format]) * 3) - else: - qkv_layout = "not_supported" - - return qkv_layout - - qkv_layout = run_iteratively(q, k, v) - if qkv_layout == "not_supported": - # force q,k,v to be contiguous and run get_layout again - q, k, v = [x.contiguous() for x in [q, k, v]] - qkv_layout = run_iteratively(q, k, v) - if qkv_layout == "not_supported": - raise RuntimeError("The provided qkv memory layout is not supported!") - - return qkv_layout, q, k, v - - -def check_set_window_size( - attn_mask_type: str, - window_size: Tuple[int, int] = None, -): - """Check if sliding window size is compliant with attention mask type. - If not, set it to the appropriate size. - - attn_mask_type | window_size - ------------------------------------------------------------------------- - no_mask, padding, arbitrary | (-1, -1) or (>=0, >=0) - causal, padding_causal | (-1, 0) or (>=0, 0) - causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0) - """ - orig_window_size = window_size - if "causal" in attn_mask_type: - if orig_window_size is None: - window_size = (-1, 0) - elif orig_window_size == (-1, -1) or ( - orig_window_size[0] >= 0 and orig_window_size[1] != 0 - ): - window_size = (orig_window_size[0], 0) - warnings.warn( - "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type - ) - elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0): - assert False, ( - "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type - ) - elif attn_mask_type in ["no_mask", "padding", "arbitrary"]: - if orig_window_size is None: - window_size = (-1, -1) - elif orig_window_size == (-1, 0): - window_size = (-1, -1) - warnings.warn( - "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type - ) - elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0): - assert 
False, ( - "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type - ) - else: - assert False, "Invalid attn_mask_type: " + attn_mask_type - return window_size - - class FlashAttention(torch.nn.Module): """Dot product attention, using HazyResearch flash-attn package: https://github.com/Dao-AILab/flash-attention @@ -5713,13 +3911,13 @@ def __init__( ) -> None: super().__init__() - if _flash_attn_is_installed: + if fa_utils.is_installed: assert ( - _flash_attn_version >= _flash_attn_version_required - ), f"FlashAttention minimum version {_flash_attn_version_required} is required." + fa_utils.version >= fa_utils.version_required + ), f"FlashAttention minimum version {fa_utils.version_required} is required." assert ( - _flash_attn_version <= _flash_attn_max_version - ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." + fa_utils.version <= fa_utils.max_version + ), f"FlashAttention maximum version {fa_utils.max_version} is supported." self.softmax_scale = softmax_scale self.attention_dropout_ctx = attention_dropout_ctx @@ -5728,9 +3926,9 @@ def __init__( self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic self.logger = logging.getLogger("FlashAttention") - self.logger.setLevel(_log_level) + self.logger.setLevel(attn_log._log_level) if not self.logger.hasHandlers(): - self.logger.addHandler(_stream_handler) + self.logger.addHandler(attn_log._stream_handler) def forward( self, @@ -5834,11 +4032,13 @@ def forward( assert ( attention_mask is not None ), "Please provide attention_mask for padding!" 
- cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask) + cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices( + attention_mask + ) else: - indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q) cu_seqlens_kv = cu_seqlens_q - query_layer, key_layer, value_layer = PackTensors.apply( + query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply( indices_q, query_layer, key_layer, value_layer ) else: @@ -5846,23 +4046,29 @@ def forward( assert ( attention_mask is not None ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) - cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) + cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices( + attention_mask[0] + ) + cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices( + attention_mask[1] + ) else: - indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) - query_layer = PackTensors.apply(indices_q, query_layer) - key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) + indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q) + indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv) + query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer) + key_layer, value_layer = dpa_utils.PackTensors.apply( + indices_kv, key_layer, value_layer + ) else: # Cumulative sequence lengths for unpadded data if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_q, query_layer.device, ) if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_kv, key_layer.device, @@ -5921,28 +4127,26 @@ def forward( with self.attention_dropout_ctx(): fa_optional_forward_kwargs = {} - if 
_flash_attn_2_3_plus: + if fa_utils.v2_3_plus: fa_optional_forward_kwargs["window_size"] = window_size - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_args_thd = [] if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: - func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 + func = flash_attn_func if not fa_utils.use_v3 else flash_attn_func_v3 else: - if _flash_attn_2_5_7_plus: + if fa_utils.v2_5_7_plus: fa_optional_forward_kwargs["block_table"] = None func = ( - flash_attn_varlen_func - if not _use_flash_attn_3 - else flash_attn_varlen_func_v3 + flash_attn_varlen_func if not fa_utils.use_v3 else flash_attn_varlen_func_v3 ) fa_optional_forward_args_thd.append(cu_seqlens_q) fa_optional_forward_args_thd.append(cu_seqlens_kv) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) - if _use_flash_attn_3: + if fa_utils.use_v3: fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["deterministic"] = self.deterministic @@ -5994,12 +4198,12 @@ def convert_to_torch_float8(tensor, dtype): **fa_3_optional_forward_kwargs, ) except TypeError as e: - if _flash_attn_3_0_0_beta: + if fa_utils.v3_0_0_beta: e.args = ( e.args[0] + ". Please update your flash-attn v3 (beta) installation as it " + "may have added more supported arguments to its API. 
\n" - + _flash_attn_3_installation_steps, + + fa_utils.v3_installation_steps, ) + e.args[1:] raise @@ -6021,7 +4225,7 @@ def convert_to_torch_float8(tensor, dtype): ) if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: - output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) @@ -6122,7 +4326,7 @@ def forward( fake_dtype = q.dtype QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) ) if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -6679,20 +4883,20 @@ def forward( "Please provide attention_mask or cu_seqlens for padding!" ) if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) else: if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_q, query_layer.device, ) if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_kv, key_layer.device, @@ -6953,15 +5157,15 @@ def __init__( super().__init__() self.logger = logging.getLogger("DotProductAttention") - self.logger.setLevel(_log_level) + self.logger.setLevel(attn_log._log_level) if not self.logger.hasHandlers(): - self.logger.addHandler(_stream_handler) + self.logger.addHandler(attn_log._stream_handler) self.qkv_format = qkv_format attn_mask_type 
= attn_mask_type.replace(",", "_") if attn_mask_type == "causal_padding": attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type - self.window_size = check_set_window_size(attn_mask_type, window_size) + self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -7418,7 +5622,7 @@ def forward( if window_size is None: window_size = self.window_size - window_size = check_set_window_size(attn_mask_type, window_size) + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -7552,18 +5756,18 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) else: - cu_seqlens_q = _get_full_cu_seqlens( + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_q, query_layer.device, ) - cu_seqlens_kv = _get_full_cu_seqlens( + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_kv, key_layer.device, @@ -7574,11 +5778,13 @@ def forward( and isinstance(key_layer, Float8Tensor) and isinstance(value_layer, Float8Tensor) ): - qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout( - query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format + qkv_layout, query_layer._data, key_layer._data, value_layer._data = ( + dpa_utils.get_qkv_layout( + query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format + ) ) else: - qkv_layout, query_layer, key_layer, value_layer = 
get_qkv_layout( + qkv_layout, query_layer, key_layer, value_layer = dpa_utils.get_qkv_layout( query_layer, key_layer, value_layer, qkv_format=qkv_format ) @@ -7640,7 +5846,7 @@ def forward( else: pad_between_seqs = False - attention_params = AttentionParams( + attention_params = dpa_utils.AttentionParams( qkv_type=type(query_layer), qkv_dtype=query_layer.dtype, qkv_layout=qkv_layout, @@ -7667,7 +5873,7 @@ def forward( fp8=self.fp8, fp8_meta=self.fp8_meta, ) - global _attention_backends, _use_flash_attn_3 + global _attention_backends if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -7675,18 +5881,25 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: - _use_flash_attn_3 = _flash_attn_3_is_installed + fa_utils.use_v3 = fa_utils.v3_is_installed ( use_flash_attention, use_fused_attention, fused_attention_backend, use_unfused_attention, _, - ) = get_attention_backend(attention_params) + ) = dpa_utils.get_attention_backend(attention_params) + # Set global _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False if use_flash_attention: self.logger.info( "Running with FlashAttention backend (version %s)", - _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version, + fa_utils.version if not fa_utils.use_v3 else fa_utils.fa3_version, ) elif use_fused_attention: self.logger.info( @@ -7703,7 +5916,8 @@ def forward( if use_flash_attention: if core_attention_bias_type == "alibi": - 
alibi_slopes, _ = get_alibi( + alibi_slopes, _ = dpa_utils.get_alibi( + _alibi_cache, query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, @@ -7738,7 +5952,8 @@ def forward( alibi_slopes is not None or max_seqlen_q != max_seqlen_kv ): fu_core_attention_bias_type = "post_scale_bias" - _, fu_core_attention_bias = get_alibi( + _, fu_core_attention_bias = dpa_utils.get_alibi( + _alibi_cache, query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, @@ -8025,7 +6240,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type - self.window_size = check_set_window_size(attn_mask_type, window_size) + self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) self.layer_number = layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -8385,7 +6600,7 @@ def forward( attn_mask_type = self.attn_mask_type if window_size is None: window_size = self.window_size - window_size = check_set_window_size(attn_mask_type, window_size) + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: diff --git a/transformer_engine/pytorch/dot_product_attention/__init__.py b/transformer_engine/pytorch/dot_product_attention/__init__.py new file mode 100644 index 0000000000..6a4c84f47d --- /dev/null +++ b/transformer_engine/pytorch/dot_product_attention/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Python interface for dot product attention""" diff --git a/transformer_engine/pytorch/dot_product_attention/inference.py b/transformer_engine/pytorch/dot_product_attention/inference.py new file mode 100644 index 0000000000..6371bdab57 --- /dev/null +++ b/transformer_engine/pytorch/dot_product_attention/inference.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    State object handed to the main model at inference time so that the
    attention context can be computed once, stored, and reused.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Both running offsets start at zero.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps layer_number -> (key_memory, value_memory); starts empty.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorder every cached key/value pair along its batch axis (axis 1).

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices used to reorder the batch dimension of
                       the KV cache. Its length must equal the cached batch size.

        Raises
        ------
        ValueError
            If nothing has been cached yet.
        AssertionError
            If `len(batch_indices)` differs from the batch size of a cached key.
        """
        cache = self.key_value_memory_dict
        if not cache:
            raise ValueError("should not swap when dict in empty")

        # Only values of existing keys are replaced, so iterating the dict
        # while assigning into it is safe (its size never changes).
        for layer_number in cache:
            key_memory, value_memory = cache[layer_number]
            # make sure batch size is the same
            assert len(batch_indices) == key_memory.shape[1]
            cache[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Rotary Position Embedding (https://arxiv.org/abs/2104.09864).

    Calling the module yields the per-position rotation angles that
    `apply_rotary_pos_emb` consumes.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Fraction of `dim` used for rotary position embeddings. Values below 1.0
            shrink `dim` to `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int, default = None
            if not None, discrete positions are interpolated by this factor via the
            trick in https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            Base raised to the power `2i / dim` to obtain the inverse frequencies.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
        self.rotary_base = rotary_base
        # Exponents 0/dim, 2/dim, 4/dim, ... created on the current CUDA device.
        exponents = (
            torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
            / dim
        )
        self.register_buffer("inv_freq", 1.0 / (self.rotary_base**exponents))

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset added to every position before the frequencies are computed

        Returns
        -------
        torch.Tensor
            Tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]`.
        """
        positions = torch.arange(
            max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        positions = positions + offset

        pretrained_len = self.pretrained_max_position_embeddings
        interpolation = self.seq_len_interpolation_factor
        if pretrained_len is not None and interpolation is not None:
            if max_seq_len > pretrained_len * interpolation:
                # dynamic linear scaling (length > position we have learned)
                scale = 1 / (max_seq_len / pretrained_len)
            else:
                # fixed linear scaling
                scale = 1 / interpolation
            positions *= scale

        freqs = torch.einsum("i , j -> i j", positions, self.inv_freq)
        # The frequency table is repeated twice along the last axis, which
        # doubles the size of that axis.
        emb = torch.cat((freqs, freqs), dim=-1)
        seq_len, width = emb.size(0), emb.size(1)
        # emb [seq_length, .., dim]
        return emb.reshape(seq_len, 1, 1, width)


class FusedRoPEFunc(torch.autograd.Function):
    """
    Autograd function wrapping the fused RoPE kernels.

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # freqs are cast to float32 before being passed to the kernels.
        if freqs.dtype != torch.float32:
            freqs = freqs.float()

        if tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
        elif tensor_format == "bshd":
            # Run on the transposed tensor, then transpose back.
            seq_first = tex.fused_rope_forward(t.transpose(0, 1), freqs, True)
            output = seq_first.transpose(0, 1)
        elif tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")

        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        freqs, cu_seqlens = ctx.saved_tensors
        layout = ctx.tensor_format

        if layout == "thd":
            grad_input = tex.fused_rope_thd_backward(
                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
            )
        elif layout == "bshd":
            seq_first_grad = tex.fused_rope_backward(grad_output.transpose(0, 1), freqs, True)
            grad_input = seq_first_grad.transpose(0, 1)
        elif layout == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # One gradient slot per forward input; only `t` receives a gradient.
        return grad_input, None, None, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Split the last dimension into two halves `(a, b)` and return `(-b, a)`.
    """
    half = x.shape[-1] // 2
    halves = x.view(*x.shape[:-1], 2, half)
    first_half = halves[..., 0, :]
    second_half = halves[..., 1, :]
    return torch.cat((-second_half, first_half), dim=-1)


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use the fused RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
    """
    if fused:
        assert not (
            tensor_format == "thd" and cu_seqlens is None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    batch_first = tensor_format == "bshd"
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if batch_first else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    angles = freqs[:cur_seq_len]
    if batch_first:
        angles = angles.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(angles).to(t.dtype)
    sin_ = torch.sin(angles).to(t.dtype)

    # Only the leading `rot_dim` channels are rotated; the rest pass through untouched
    # (ideally the pass-through slice is empty).
    rot_dim = angles.shape[-1]
    rotated = t[..., :rot_dim]
    passthrough = t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    rotated = (rotated * cos_) + (_rotate_half(rotated) * sin_)
    return torch.cat((rotated, passthrough), dim=-1)
+ +""" +Utils/Helper classes and methods for attention +""" +import math +import os +from typing import Any, Dict, List, Optional, Tuple, Union +import warnings +import logging +import functools + +from dataclasses import dataclass, fields +import numpy as np +from packaging.version import Version as PkgVersion + +import torch +import torch.nn.functional as F +import transformer_engine_torch as tex +import transformer_engine as te +from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + QKVLayout, + AttnBiasType, + AttnMaskType, + FusedAttnBackend, + META_QKV, + META_DQKV, + META_O, + META_DO, + META_S, + META_DP, + META_O_CP, + META_DQKV_CP, +) +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.constants import TE_DType + + +from transformer_engine.pytorch.utils import ( + get_device_compute_capability, + get_cudnn_version, +) + +from transformer_engine.pytorch.jit import jit_fuser + +# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) +# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 +_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) +_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) + + +class AttentionLogging: + """ + Manage logging for attention module + """ + + _log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL + _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") + _stream_handler = logging.StreamHandler() + fa_logger = logging.getLogger(__name__) + + @staticmethod + def setup_logging(): + """ + Set up log levels, logger and handlers + """ + _log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + AttentionLogging._log_level = _log_levels[ + AttentionLogging._log_level if AttentionLogging._log_level in [0, 1, 2] else 2 + ] + 
AttentionLogging._stream_handler.setFormatter(AttentionLogging._formatter) + AttentionLogging.fa_logger.setLevel(AttentionLogging._log_level) + if not AttentionLogging.fa_logger.hasHandlers(): + AttentionLogging.fa_logger.addHandler(AttentionLogging._stream_handler) + + +@functools.lru_cache(maxsize=None) +def _get_supported_versions(version_min, version_max): + """ + Calculate version info based on min and max numbers + """ + return ">= " + str(version_min) + ", " + "<= " + str(version_max) + + +class FlashAttentionUtils: + """ + Manage Flash Attention versioning information + """ + + # Detect flash-attn v2 in the environment + is_installed = False + version = PkgVersion("0") + version_required = PkgVersion("2.1.1") + version_required_blackwell = PkgVersion("2.7.3") + max_version = PkgVersion("2.7.4.post1") + v2_plus = False + v2_1_plus = False + v2_3_plus = False + v2_4_plus = False + v2_4_1_plus = False + v2_5_7_plus = False + v2_6_0_plus = False + v2_7_0_plus = False + + v3_is_installed = False + fa3_version = PkgVersion("0") + v3_0_0_beta = False + use_v3 = False + # TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved + # https://github.com/Dao-AILab/flash-attention/issues/1452 + v3_installation_steps = """\ + (1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" + (2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` + (3) mkdir -p $python_path/flashattn_hopper + (4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" + + @staticmethod + def set_flash_attention_version(): + """ + Setup version info for FA v2.x + """ + FlashAttentionUtils.is_installed = True + FlashAttentionUtils.v2_plus = FlashAttentionUtils.version >= PkgVersion("2") + FlashAttentionUtils.v2_1_plus = FlashAttentionUtils.version >= PkgVersion("2.1") + FlashAttentionUtils.v2_3_plus = 
FlashAttentionUtils.version >= PkgVersion("2.3") + FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4") + FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1") + FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7") + FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0") + FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0") + + # Detect flash-attn v3 in the environment + # This section will be removed when FA3 is released as a regular FA package, + # i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0 + @staticmethod + def set_flash_attention_3_params(): + """ + Setup version info for FA v3.x + """ + FlashAttentionUtils.v3_is_installed = True + FlashAttentionUtils.v3_0_0_beta = ( + PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0") + ) + FlashAttentionUtils.use_v3 = True + + +@dataclass(eq=True) +class AttentionParams: + """ + Attention parameters used to determine which backend to be used. + + Parameters + ---------- + qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor` + Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}. + qkv_dtype: torch.dtype, default = `torch.bfloat16` + Data type of query/key/value tensors. + qkv_layout: str, default = "sbh3d" + Query/key/value tensor memory layout. + batch_size: int, default = 1 + Batch size. + num_heads: int, default = 16 + Number of attention heads in the query tensor. + num_gqa_groups: int, default = 16 + Number of attention heads in key and value tensors. + max_seqlen_q: int, default = 128 + Maximum sequence length of the query tensor. + max_seqlen_kv: int, default = 128 + Maximum sequence length of the key and value tensors. + head_dim_qk: int, default = 64 + The size of each attention head in query and key tensors. + head_dim_v: int, default = 64 + The size of each attention head in the value tensor. 
+ attn_mask_type: str, default = `no_mask` + Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, + `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} + window_size: Tuple[int, int], default = None + Sliding window attention size. + alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` + Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. + core_attention_bias_type: str, default = `no_bias` + Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}. + core_attention_bias_shape: str, default = `1hss` + Attention bias shape, {`1hss`, `b1ss`, `bhss`}. + core_attention_bias_requires_grad: bool, default = `True` + Whether attention bias requires gradient. + pad_between_seqs: bool, default = `False` + Whether there is padding between sequences in a batch. + This only applies to `qkv_format=thd`. + attention_dropout: float, default = 0.0 + Attention dropout. + context_parallel: bool, default = `False` + Whether context parallelism is used or not. + deterministic: bool, default = `False` + Whether to run `DotProductAttention` with determinism or not. + is_training: bool, default = `True` + Whether in training mode (`True`) or inference mode (`False`) + fp8: bool, default = `False` + Whether `DotProductAttention` is in an `fp8_autocast` region. + fp8_meta: Optional[Dict[str Any]], default = `None` + The FP8 metadata tensor of `DotProductAttention`. 
+ """ + + qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor + qkv_dtype: torch.dtype = torch.bfloat16 + qkv_layout: str = "sbh3d" + batch_size: int = 1 + num_heads: int = 16 + num_gqa_groups: int = 16 + max_seqlen_q: int = 128 + max_seqlen_kv: int = 128 + head_dim_qk: int = 64 + head_dim_v: int = 64 + attn_mask_type: str = "no_mask" + window_size: Union[Tuple[int, int], None] = None + alibi_slopes_shape: Union[torch.Size, List, None] = None + core_attention_bias_type: str = "no_bias" + core_attention_bias_shape: str = "1hss" + core_attention_bias_requires_grad: bool = True + pad_between_seqs: bool = False + attention_dropout: float = 0.0 + context_parallel: bool = False + deterministic: bool = False + is_training: bool = True + fp8: bool = False + fp8_meta: Union[Dict[str, Any], None] = None + + def __eq__(self, other): + """ + Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared, + since all other entries of fp8_meta are unused in get_attention_backend. + """ + if not isinstance(other, self.__class__): + return NotImplemented + for field in fields(self): + fname = field.name + sf = getattr(self, fname) + of = getattr(other, fname) + if fname != "fp8_meta": + if sf != of: + return False + elif sf.get("recipe", None) != of.get("recipe", None): + return False + return True + + +def get_attention_backend( + attention_params: AttentionParams = None, +): + """ + Select the appropriate attention backend/sub-backend based on user input and runtime environment. + + Parameters + ---------- + See `AttentionParams`. + + Returns + ---------- + use_flash_attention: bool + Whether the `FlashAttention` backend has been selected. + use_fused_attention: bool + Whether the `FusedAttention` backend has been selected. + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`. 
+ use_unfused_attention: bool + Whether the `UnfusedDotProductAttention` backend has been selected. + available_backends: List[bool] + All available backends that could support the provided input. A list of Booleans + in the form of [use_flash_attention, use_fused_attention, use_unfused_attention]. + """ + # NOTE: As part of refactoring attention.py, populating the _attention_backends cache in attention + # is no longer performed at the end of get_attention_backend(), but the responsibility of doing so + # is shifted over to the caller of this function + qkv_type = attention_params.qkv_type + qkv_dtype = attention_params.qkv_dtype + qkv_layout = attention_params.qkv_layout + batch_size = attention_params.batch_size + num_heads = attention_params.num_heads + num_gqa_groups = attention_params.num_gqa_groups + max_seqlen_q = attention_params.max_seqlen_q + max_seqlen_kv = attention_params.max_seqlen_kv + head_dim_qk = attention_params.head_dim_qk + head_dim_v = attention_params.head_dim_v + attn_mask_type = attention_params.attn_mask_type + window_size = attention_params.window_size + alibi_slopes_shape = attention_params.alibi_slopes_shape + core_attention_bias_type = attention_params.core_attention_bias_type + core_attention_bias_shape = attention_params.core_attention_bias_shape + core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad + pad_between_seqs = attention_params.pad_between_seqs + attention_dropout = attention_params.attention_dropout + context_parallel = attention_params.context_parallel + deterministic = attention_params.deterministic + is_training = attention_params.is_training + fp8 = attention_params.fp8 + fp8_meta = attention_params.fp8_meta + + # Run config + logger = logging.getLogger("DotProductAttention") + logger.setLevel(AttentionLogging._log_level) + if not logger.hasHandlers(): + logger.addHandler(AttentionLogging._stream_handler) + device_compute_capability = get_device_compute_capability() + cudnn_version = 
get_cudnn_version() + run_config = { + "transformer_engine_version": te.__version__, + "compute_capability": "sm" + + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "flash_attn_version": ( + str(FlashAttentionUtils.version) + if FlashAttentionUtils.is_installed + else "not installed" + ), + "flash_attn_3_version": ( + str(FlashAttentionUtils.fa3_version) + if FlashAttentionUtils.v3_is_installed + else "not installed" + ), + "cudnn_version": ".".join([str(i) for i in cudnn_version]), + } + attention_params_dict = { + field.name: getattr(attention_params, field.name) for field in fields(attention_params) + } + run_config.update(attention_params_dict) + if fp8: + run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + logger.debug("Running with config=%s", run_config) + + # The following sections check if `FlashAttention` supports the provided attention params, + # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is + # necessary for performance/functionality, a warning will be issued to prompt users to + # install an appropriate FA version. 
+ + # Filter: Environment variables + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + if not use_flash_attention and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_fused_attention: + logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") + if not use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + + # Filter: Compute capability + if device_compute_capability < (8, 0): + if use_flash_attention and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention as it requires compute capability sm80+") + use_flash_attention = False + if use_fused_attention: + logger.debug("Disabling FusedAttention as it requires compute capability sm80+") + use_fused_attention = False + if device_compute_capability < (9, 0): + if use_flash_attention and FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") + FlashAttentionUtils.use_v3 = False + + # Filter: Data type + if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ + torch.Tensor, + Float8Tensor, + ]: + if use_flash_attention and FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_flash_attention = False + if use_fused_attention: + logger.debug( + "Disabling FusedAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. 
" + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_fused_attention = False + + # Filter: Execution type + if fp8 and fp8_meta["recipe"].fp8_dpa: + if use_flash_attention and not FlashAttentionUtils.use_v3: + if FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") + use_flash_attention = False + if use_flash_attention and FlashAttentionUtils.use_v3 and is_training: + logger.debug( + "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" + ) + use_flash_attention = False + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") + use_unfused_attention = False + + # Filter: Head dimension + if use_flash_attention and head_dim_qk != head_dim_v: + if FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention as it does not support MLA.") + use_flash_attention = False + if use_flash_attention and ( + head_dim_qk > 256 + or head_dim_qk % 8 != 0 + or ( + head_dim_qk > 192 + and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) + ) + ): + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " + "head_dim_qk <= 256 (>192 requires sm80/90/100+). 
" + "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", + head_dim_qk, + head_dim_v, + ".".join([str(i) for i in device_compute_capability]), + ) + use_flash_attention = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False + + # Filter: QKV layout + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") + use_unfused_attention = False + if use_flash_attention and pad_between_seqs: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention for qkv_format = thd when there is " + "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" + ) + use_flash_attention = False + + # Filter: Dropout + if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3: + logger.debug("Disabling FlashAttention 3 for dropout") + FlashAttentionUtils.use_v3 = False + + # Filter: Context parallelism + # qkv_format | attn_mask_type | attn_bias_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention + # | no_mask, causal | | + # | cross-attention: | | + # | no_mask | | + # thd | self-attention: | no_bias | FlashAttention, FusedAttention + # | padding, padding_causal | | if no padding between sequences, + # | cross-attention: | | FusedAttention + # | padding | | if there is padding between sequences + # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v. 
+ if context_parallel and use_unfused_attention: + logger.debug( + "Disabling UnfusedDotProductAttention as it does not support context parallelism" + ) + use_unfused_attention = False + if context_parallel and use_flash_attention: + if fp8 and fp8_meta["recipe"].fp8_dpa: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with FP8" + ) + use_flash_attention = False + if "bottom_right" in attn_mask_type: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_flash_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal masking for cross-attention" + ) + use_flash_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with bias" + " type of %s", + core_attention_bias_type, + ) + use_flash_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " attention bias for THD format" + ) + use_flash_attention = False + + if context_parallel and use_fused_attention: + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_fused_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with causal" + " masking for cross-attention" + ) + use_fused_attention = False + 
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias type" + " of %s", + core_attention_bias_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with attention" + " bias for THD format" + ) + use_fused_attention = False + elif head_dim_qk != head_dim_v: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with MLA" + ) + use_fused_attention = False + + # Filter: Attention mask + # attn_mask_type | attention_mask | supported backends + # ---------------------------------------------------------------------------------------- + # no_mask | None | All + # padding | | All + # self-attention | One tensor in shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors in shapes | + # | [b, 1, 1, sq] and [b, 1, 1, skv] | + # causal | None | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # padding_causal | Same as "padding" | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # causal_bottom_right | None | All + # padding_causal_bottom_right | Same as "padding" | All + # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # | [b, h, sq, skv] | + if attn_mask_type == "arbitrary": + if use_flash_attention and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention for arbitrary mask") + use_flash_attention = False + if use_fused_attention: + logger.debug("Disabling FusedAttention for arbitrary mask") + use_fused_attention = False + if ( + use_flash_attention + and FlashAttentionUtils.use_v3 + and attn_mask_type in ["causal", "padding_causal"] + and max_seqlen_q != max_seqlen_kv + ): + logger.warning( + "Disabling FlashAttention 3 as it only 
supports bottom-right-diagonal " + "causal mask since flash-attn 2.1. See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + FlashAttentionUtils.use_v3 = False + if ( + use_flash_attention + and attn_mask_type in ["causal", "padding_causal"] + and max_seqlen_q != max_seqlen_kv + ): + if FlashAttentionUtils.v2_1_plus: + logger.warning( + "Disabling FlashAttention as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1. See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + use_flash_attention = False + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.max_version = PkgVersion("2.1") + if ( + use_flash_attention + and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"] + and max_seqlen_q != max_seqlen_kv + ): + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.1") + elif not FlashAttentionUtils.v2_1_plus and not FlashAttentionUtils.use_v3: + logger.warning( + "Disabling FlashAttention as it only supports top-left-diagonal " + "causal mask before flash-attn 2.1. 
See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + use_flash_attention = False + if ( + use_flash_attention + and FlashAttentionUtils.use_v3 + and fp8 + and fp8_meta["recipe"].fp8_dpa + and "padding" in attn_mask_type + ): + logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") + FlashAttentionUtils.use_v3 = False + + # Filter: Sliding window attention + # backend | window_size | diagonal alignment + # --------------------------------------------------------------------------------- + # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right + # FusedAttention | (-1, 0) or (>=0, 0) | top left + # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; + # | | converts window_size to an 'arbitrary' mask + if window_size is None: + window_size = check_set_window_size(attn_mask_type, window_size) + else: + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention" + " for FP8" + ) + use_fused_attention = False + elif window_size[1] != 0 or attention_dropout != 0.0: + logger.debug( + "Disabling FusedAttention as it only supports sliding window attention " + "with (left, 0) and no dropout" + ) + use_fused_attention = False + elif max_seqlen_q > max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with s_q > s_kv for cross-attention" + ) + use_fused_attention = False + if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if FlashAttentionUtils.use_v3: + logger.debug( + "Disabling FlashAttention 3 as it does not support sliding window attention" + ) + FlashAttentionUtils.use_v3 = False + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.3") + elif not FlashAttentionUtils.v2_3_plus: + 
logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention = False + + # Filter: Attention bias + # backend | bias types | ALiBi diagonal alignment + # --------------------------------------------------------------------------------- + # FlashAttention | no_bias, alibi/alibi_slopes | bottom right + # FusedAttention | no_bias, post_scale_bias | + # | alibi/alibi_slopes | top left, + # | | bottom_right (converts to a 'post_scale_bias' bias) + # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | + # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias + if use_flash_attention and core_attention_bias_type == "alibi": + if FlashAttentionUtils.use_v3: + logger.debug("Disabling FlashAttention 3 for ALiBi") + FlashAttentionUtils.use_v3 = False + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.4") + elif not FlashAttentionUtils.v2_4_plus: + logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") + use_flash_attention = False + + if use_flash_attention and ( + core_attention_bias_type not in ["no_bias", "alibi"] + or core_attention_bias_shape is not None + ): + if FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention for pre/post_scale_bias") + use_flash_attention = False + + fu_core_attention_bias_type = core_attention_bias_type + fu_core_attention_bias_shape = core_attention_bias_shape + fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad + if ( + use_fused_attention + and core_attention_bias_type == "alibi" + and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + ): + fu_core_attention_bias_type = "post_scale_bias" + fu_core_attention_bias_requires_grad = False + if alibi_slopes_shape is None: + fu_core_attention_bias_shape = "1hss" + elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + fu_core_attention_bias_shape = "1hss" + elif ( + 
len(alibi_slopes_shape) == 2 + and alibi_slopes_shape[0] == batch_size + and alibi_slopes_shape[1] == num_heads + ): + fu_core_attention_bias_shape = "bhss" + + if ( + use_fused_attention + and fu_core_attention_bias_type == "post_scale_bias" + and fu_core_attention_bias_shape != "1hss" + ): + if fu_core_attention_bias_requires_grad: + # remove this line when cuDNN adds bwd support for + # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] + logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") + use_fused_attention = False + else: + # max512 backend will only support [1, h, s, s] + os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" + + # Filter: cuDNN support + fused_attention_backend = None + if use_fused_attention: + q_type = TE_DType[qkv_dtype] + kv_type = q_type + if fp8 and fp8_meta["recipe"].fp8_dpa: + q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + kv_type = q_type + fused_attention_backend = tex.get_fused_attn_backend( + q_type, + kv_type, + QKVLayout[qkv_layout], + AttnBiasType[fu_core_attention_bias_type], + AttnMaskType[attn_mask_type], + attention_dropout, + num_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size[0], + window_size[1], + ) + if fused_attention_backend == FusedAttnBackend["No_Backend"]: + logger.debug("Disabling FusedAttention as no backend supports the provided input") + use_fused_attention = False + fused_attention_backend = None + if ( + use_fused_attention + and window_size is not None + and window_size[0] != -1 + and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + ): + logger.debug( + "Disabling FusedAttention as only sub-backend %s does not support " + "slidng window attention", + int(fused_attention_backend), + ) + use_fused_attention = False + fused_attention_backend = None + if ( + use_fused_attention + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] + and fu_core_attention_bias_type == "post_scale_bias" + and 
fu_core_attention_bias_shape != "1hss" + ): + logger.debug( + "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in" + " [1, H, S, S] shape" + ) + use_fused_attention = False + fused_attention_backend = None + + # Filter: Determinism + # backend | deterministic + # --------------------------------------------- + # FlashAttention | + # flash-attn >=2.0, <2.4.1 | no + # flash-attn >=2.4.1 | yes + # FusedAttention | + # sub-backend 0 | yes + # sub-backend 1 | workspace optimization path and sm90+: yes; + # | otherwise: no + # sub-backend 2 | no + # UnfusedDotProductAttention | yes + if use_flash_attention and deterministic: + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.4.1") + elif not FlashAttentionUtils.v2_4_1_plus and not FlashAttentionUtils.use_v3: + logger.warning( + "Disabling FlashAttention as version <2.4.1 does not support deterministic " + "execution. To use FlashAttention with deterministic behavior, " + "please install flash-attn >= 2.4.1." + ) + use_flash_attention = False + if use_fused_attention and deterministic: + if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: + logger.debug("Disabling FusedAttention for determinism reasons") + use_fused_attention = False + if ( + fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + and is_training + and ( + device_compute_capability < (9, 0) + or core_attention_bias_requires_grad + or cudnn_version < (8, 9, 5) + ) + ): + logger.debug("Disabling FusedAttention for determinism reasons") + use_fused_attention = False + + # All available backends + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + + # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. + # When `FusedAttention` does not support the provided attention params, and `FlashAttention` + # does, we recommend users to install flash-attn if not installed already. 
+ if not use_fused_attention and use_flash_attention and not FlashAttentionUtils.is_installed: + logger.warning( + "flash-attn may provide important feature support or performance improvement." + " Please install flash-attn %s.", + _get_supported_versions( + FlashAttentionUtils.version_required, + FlashAttentionUtils.max_version, + ), + ) + if use_flash_attention and not FlashAttentionUtils.is_installed: + use_flash_attention = False + available_backends[0] = False + + logger.debug( + "Available backends = {FlashAttention=%s, FusedAttention=%s%s," + " UnfusedDotProductAttention=%s}", + bool(available_backends[0]), + bool(available_backends[1]), + ( + f" (sub-backend {int(fused_attention_backend)})" + if fused_attention_backend is not None + else "" + ), + bool(available_backends[2]), + ) + + # Select FusedAttention for performance + if ( + use_flash_attention + and use_fused_attention + and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + ): + if device_compute_capability >= (9, 0): + logger.debug( + "Disabling FlashAttention to give FusedAttention preference on Hopper+ " + "for performance reasons" + ) + use_flash_attention = False + if ( + use_flash_attention + and use_fused_attention + and fused_attention_backend == FusedAttnBackend["FP8"] + and FlashAttentionUtils.use_v3 + ): + logger.debug( + "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " + "in FP8 execution" + ) + use_flash_attention = False + + # Selected backend + if use_flash_attention: + use_fused_attention = False + use_unfused_attention = False + elif use_fused_attention: + use_unfused_attention = False + selected_backend = "NoBackend" + if use_flash_attention: + selected_backend = "FlashAttention" + elif use_fused_attention: + selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" + elif use_unfused_attention: + selected_backend = "UnfusedDotProductAttention" + logger.debug("Selected backend = %s", 
@torch.no_grad()
def get_full_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    attn_mask_type: str = "no_mask",
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
    window_size: Tuple[int, int] = None,
    attention_type: str = "self",
    bottom_right_alignment: bool = True,
) -> Tuple[str, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::

       attn_mask_type              output shape                                 diagonal alignment
       --------------------------------------------------------------------------------------------
       no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          follow bottom_right_alignment
       causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          always top left
       causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          always bottom right
       padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
       padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
       padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
       arbitrary                   same as attention_mask                       follow bottom_right_alignment

    .. note::

    For "padding_causal_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True,
    the bottom right diagonal comes from the bottom right corner of the
    [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, i = 0,...,batch_size-1, not the
    [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4, max_seqlen_kv = 4,
    attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
    [[False, False,  True, True], [False, False, False, False]],
    [[False, False, False, True], [False,  True,  True,  True]]), each reshaped to
    [2, 1, 1, 4], the returned full attention mask has [2, 1, 4, 4] shape and is,::

      [[[[False, False, False, True],
         [False, False, False, True],
         [ True,  True,  True, True],
         [ True,  True,  True, True]]],
       [[[False,  True,  True, True],
         [False,  True,  True, True],
         [False,  True,  True, True],
         [False,  True,  True, True]]]]

    The intermediate index tensors are created on the device of `attention_mask` when one is
    given, and on "cuda" otherwise.

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    attn_mask_type: str, default = `no_mask`
        Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
        "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        default = `None`
        Boolean tensor(s) used to mask out attention softmax input (`True` means masked out).
        Please see DotProductAttention for the requirements of `attention_mask` for
        different `attn_mask_type`s.
    window_size: Tuple[int, int], default = `None`
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
        `attn_mask_type`.
    attention_type: str, default = "self"
        Attention type, {"self", "cross"}
    bottom_right_alignment: bool, default = `True`
        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
        specifies "causal" or "causal_bottom_right".

    Returns
    -------
    attn_mask_type: str
        "arbitrary" if `window_size` was given and describes a real sliding window, i.e.
        `window_size[0] != -1` or `window_size[1]` not in {-1, 0}; otherwise, the same as
        input `attn_mask_type`.
    attention_mask: torch.Tensor
        The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
    actual_seqlens_q: Optional[torch.Tensor]
        For padding masks, the actual sequence lengths for queries, in shape [batch_size].
        For other masks, `None`.
    actual_seqlens_kv: Optional[torch.Tensor]
        For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
        For other masks, `None`.

    Raises
    ------
    ValueError
        If a padding mask type is requested without `attention_mask`, or if `attn_mask_type`
        is not one of the supported values.
    """
    # A sliding window was explicitly requested: the result can no longer be described by
    # the original mask type and must be reported as "arbitrary".
    change_type = window_size is not None and (
        window_size[0] != -1 or window_size[1] not in [-1, 0]
    )
    if window_size is None:
        window_size = (-1, -1)
    if "causal" in attn_mask_type:
        window_size = (window_size[0], 0)
    # Replace the "unbounded" marker (-1) by a width that covers the whole matrix.
    window_size = (
        max_seqlen_kv if window_size[0] == -1 else window_size[0],
        max_seqlen_q if window_size[1] == -1 else window_size[1],
    )

    # Build index tensors next to the user-provided mask; fall back to the default CUDA
    # device when no mask is given.
    if attention_mask is None:
        device = "cuda"
    elif isinstance(attention_mask, (tuple, list)):
        device = attention_mask[0].device
    else:
        device = attention_mask.device

    # apply padding mask
    actual_seqlens_q = None
    actual_seqlens_kv = None
    if "padding" in attn_mask_type:
        if attention_mask is None:
            raise ValueError(
                f"attention_mask is required for attn_mask_type={attn_mask_type}"
            )
        if attention_type == "self":
            # [b, 1, 1, s] -> outer "or" of the row and column padding, [b, 1, s, s]
            attention_mask = torch.logical_or(
                attention_mask.squeeze(1).unsqueeze(3), attention_mask
            )
        else:
            attention_mask = torch.logical_or(
                attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
            )
        # Valid lengths are read off the first column / first row of the un-masked region.
        m = attention_mask.logical_not()
        actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
        actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)

    # apply SWA mask
    # mask[i, j] = i - j, the signed distance of key j from query i
    mask = torch.arange(max_seqlen_q, dtype=torch.int32, device=device).view(
        1, 1, max_seqlen_q, 1
    ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device=device).view(
        1, 1, 1, max_seqlen_kv
    )
    if attn_mask_type == "causal_bottom_right" or (
        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
    ):
        swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
        swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
    elif attn_mask_type in ["causal", "padding_causal"] or (
        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
    ):
        swa_left = mask - window_size[0]
        swa_right = mask + window_size[1]
    elif attn_mask_type == "padding_causal_bottom_right" or (
        attn_mask_type == "padding" and bottom_right_alignment
    ):
        # Bottom-right corner is per sample: use the actual, not the maximum, lengths.
        batch_size = attention_mask.shape[0]
        swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
            actual_seqlens_kv - actual_seqlens_q - window_size[0]
        ).view(batch_size, 1, 1, 1)
        swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
            actual_seqlens_kv - actual_seqlens_q + window_size[1]
        ).view(batch_size, 1, 1, 1)
    else:
        raise ValueError(f"Unsupported attn_mask_type: {attn_mask_type}")
    # A position is kept when swa_left <= 0 <= swa_right; everything else is masked out.
    swa_mask = torch.logical_not(
        torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
    )
    if attention_mask is not None:
        attention_mask = torch.logical_or(swa_mask, attention_mask)
    else:
        attention_mask = swa_mask

    # change mask type
    if change_type:
        attn_mask_type = "arbitrary"

    return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
+ max_seqlen_kv: int + Maximum sequence length for keys and values. + actual_seqlens_q: Optional[torch.Tensor], default = `None` + Actual sequence lengths for queries, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + Actual sequence lengths for keys and values, in shape [batch_size]. + alibi_slopes: Optional[torch.Tensor], default = `None` + Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. + bias_dtype: Optional[torch.dtype], default = `None` + Dtype of the generated ALiBi bias. If None, use torch.float32. + bottom_right_alignment: bool, default = `True` + Whether to align the diagonal of the ALiBi bias to the bottom right corner of + the matrix (`True`) or top left (`False`). + + Returns + ---------- + alibi_slopes: torch.Tensor + ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. + alibi_bias: torch.Tensor + ALiBi bias in FP32 or `bias_dtype`. Its shape is + (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, + and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or + (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in + [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and + `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. 
+ """ + # NOTE: As part of refactoring attention.py, get_alibi() now receives the alibi cache from the caller + # as an additional input arg + if _alibi_cache["_alibi_slopes_require_update"]: + if alibi_slopes is not None: + _alibi_cache["_alibi_slopes"] = alibi_slopes + else: + n = 2 ** math.floor(math.log2(num_heads)) + m_0 = 2.0 ** (-8.0 / n) + m = torch.pow(m_0, torch.arange(1, 1 + n)) + + if n < num_heads: + m_hat_0 = 2.0 ** (-4.0 / n) + m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2)) + m = torch.cat([m, m_hat]) + + _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda") + _alibi_cache["_num_heads"] = num_heads + _alibi_cache["_alibi_slopes_require_update"] = False + + if _alibi_cache["_alibi_bias_require_update"]: + assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!" + if _alibi_cache["_alibi_slopes"].dim() == 1: + slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) + elif _alibi_cache["_alibi_slopes"].dim() == 2: + slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) + else: + raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") + + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if actual_seqlens_q is None and actual_seqlens_kv is None: + if bottom_right_alignment: + bias = bias + max_seqlen_kv - max_seqlen_q + elif actual_seqlens_q is not None and actual_seqlens_kv is not None: + batch_size = actual_seqlens_q.shape[0] + bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + if bottom_right_alignment: + bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + else: + assert ( + False + ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" 
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    Entries of `mask` that are `False` are counted as valid tokens. The result lives
    on the same device as `mask`.
    """
    # [b, 1, 1, s] -> [b, s]; count the un-masked positions of every sample
    seqlens = mask.squeeze(1).squeeze(1).logical_not().sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Prepend the leading 0 on the mask's own device so that the concatenation never
    # mixes devices.
    return torch.cat((cu_seqlens.new_zeros(1), cu_seqlens))
def get_indices(
    max_seqlen: int,
    cu_seqlens: torch.Tensor,
    device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Token `ii` of sample `i` gets index `i * max_seqlen + ii`. The valid indices come
    first, in increasing order, and the remaining entries are filled with the
    out-of-range value `batch_size * max_seqlen`.

    Parameters
    ----------
    max_seqlen: int
        Maximum sequence length; every sequence length derived from `cu_seqlens` is
        expected to be at most `max_seqlen`.
    cu_seqlens: torch.Tensor
        Cumulative sequence lengths, in shape [batch_size + 1].
    device: Union[str, torch.device], default = "cuda"
        Device of the returned tensor.

    Returns
    -------
    indices: torch.Tensor
        int64 tensor in shape [batch_size * max_seqlen, 1, 1].
    """
    bs = len(cu_seqlens) - 1
    total = bs * max_seqlen
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64)

    # Work in int64 from the start: building the indices as float32 and casting
    # afterwards cannot represent values above 2**24 exactly.
    positions = torch.arange(max_seqlen, dtype=torch.int64, device=cu_seqlens.device)
    # valid[i, ii] is True when token ii exists in sample i
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    flat_index = torch.arange(total, dtype=torch.int64, device=cu_seqlens.device).view(
        bs, max_seqlen
    )
    # Boolean selection walks the [bs, max_seqlen] grid row by row, i.e. sample by sample.
    indices = flat_index[valid]

    padding = indices.new_full((total - indices.shape[0],), total)
    indices = torch.cat((indices, padding)).view(-1, 1, 1)

    return indices.to(device=device)
@jit_fuser
def _pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gather the rows of `tensor` selected by `indices` along dim 0.

    One all-zero row is appended to the data before gathering, so an index equal to
    `tensor.shape[0]` selects zeros. `Float8Tensor` inputs are gathered on their raw
    `_data` and re-wrapped with `Float8Tensor.make_like`.
    """
    zero_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
    )
    # Broadcast the per-row index over the two trailing dims so it can drive `gather`.
    gather_index = indices.repeat(1, tensor.shape[1], tensor.shape[2])

    if isinstance(tensor, Float8Tensor):
        padded_data = torch.cat((tensor._data, zero_row), dim=0)
        gathered = torch.gather(padded_data, 0, gather_index)
        return Float8Tensor.make_like(tensor, data=gathered, shape=gathered.shape)

    padded = torch.cat((tensor, zero_row), dim=0)
    return torch.gather(padded, 0, gather_index)


@jit_fuser
def _pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply `_pack_tensor` with the same `indices` to `t1` and `t2`, in that order.
    """
    return _pack_tensor(indices, t1), _pack_tensor(indices, t2)


@jit_fuser
def _pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply `_pack_tensor` with the same `indices` to `t1`, `t2` and `t3`, in that order.
    """
    return _pack_tensor(indices, t1), _pack_tensor(indices, t2), _pack_tensor(indices, t3)
@jit_fuser
def _unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `_pack_2_tensors`: apply `_unpack_tensor` to `t1` and `t2`, in that
    order, with the same `indices` and `dim0`.
    """
    return _unpack_tensor(indices, dim0, t1), _unpack_tensor(indices, dim0, t2)


@jit_fuser
def _unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `_pack_3_tensors`: apply `_unpack_tensor` to `t1`, `t2` and `t3`, in
    that order, with the same `indices` and `dim0`.
    """
    return (
        _unpack_tensor(indices, dim0, t1),
        _unpack_tensor(indices, dim0, t2),
        _unpack_tensor(indices, dim0, t3),
    )
def get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    The layout is inferred from the storage pointers, shapes, strides and storage
    offsets of `q`, `k` and `v`. If no supported layout matches, the three tensors are
    made contiguous and the detection is run once more.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of tokens in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    -------
    qkv_layout: str
        Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
        memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
        of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
        `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
        are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
        `v = kv[:,:,:,1,:]`.
        Mapping:
        `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
        `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
        `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        Query tensor. It may be different from input `q` as we try to fit tensors to
        a supported layout.
    k: torch.Tensor
        Key tensor. It may be different from input `k` as we try to fit tensors to
        a supported layout.
    v: torch.Tensor
        Value tensor. It may be different from input `v` as we try to fit tensors to
        a supported layout.

    Raises
    ------
    AssertionError
        If `q`, `k` or `v` does not have stride 1 in its last dimension.
    RuntimeError
        If no supported layout is found even after making the tensors contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # Classify one (q, k, v) triple; returns the layout string or "not_supported".

        # check data pointers
        # (tensors that are views of the same allocation share the storage data pointer)
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        # check tensor shapes
        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        # k and v are only compared on their leading dims; the last dim may differ
        check_shapes_kv = shape[:-1] == v.shape[:-1]

        # check tensor strides
        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        # k/v strides are compared after dividing by each tensor's own last-dim size,
        # so the comparison is independent of the size of the last dimension
        check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
            sv / v.shape[-1] for sv in v.stride()[:-1]
        )

        # check tensor offsets for h3d and 3hd layouts
        prod_h_d = q.shape[-1] * q.shape[-2]
        check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
        check_h3d_offsets = all(
            x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
        )

        # check tensor offsets for hd_h2d and hd_2hd layouts
        prod_all_dims = [np.prod(x.shape) for x in [q, k]]
        # when q shares the storage with k and v, kv starts right after all of q
        offset = prod_all_dims[0] if check_ptrs_qkv else 0
        prod_h_d = k.shape[-1] * k.shape[-2]
        check_2hd_offsets = all(
            x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
        )
        check_h2d_offsets = all(
            x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
        )

        # check tensor offsets for hd_hd_hd layouts
        # (tensors sharing a storage must be laid out back to back; tensors with their
        # own storage must start at offset 0)
        check_hd_offsets_qkv = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
            if check_ptrs_qkv
            else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v]))
        )
        check_hd_offsets_qk = (
            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
            if not check_ptrs_qkv and check_ptrs_qk
            else all(x.storage_offset() == 0 for i, x in enumerate([q, k]))
        )
        check_hd_offsets_kv = (
            all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
            if not check_ptrs_qkv and check_ptrs_kv
            else all(x.storage_offset() == 0 for i, x in enumerate([k, v]))
        )

        # The branches below are tried from the most to the least tightly packed layout;
        # the first match wins, so their order matters.
        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
            # sb3hd, bs3hd, t3hd
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
            # sbh3d, bsh3d, th3d
            # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif (
            check_strides_kv
            and check_shapes_kv
            and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
        ):
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
            # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
            # check_ptrs_qk=True or check_ptrs_kv=True
            qkv_layout = "_".join(list([qkv_format]) * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise RuntimeError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v
False, ( + "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type + ) + else: + assert False, "Invalid attn_mask_type: " + attn_mask_type + return window_size + + +def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): + """Get the list of quantizers used in attention from the quantizers list.""" + if not fp8: + num_of_nones = 8 if cp_specific_quantizers else 6 + return [None] * num_of_nones + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] + QKV_quantizer.internal = True + QKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.internal = True + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.internal = True + dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] + dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_CP_quantizer.internal = True + O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] + O_CP_quantizer.set_usage(rowwise=True, columnwise=False) + + if cp_specific_quantizers: + return ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index fbc787d6d2..d829275777 100644 --- a/transformer_engine/pytorch/transformer.py +++ 
b/transformer_engine/pytorch/transformer.py @@ -12,10 +12,10 @@ from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.attention import ( - InferenceParams, MultiheadAttention, - check_set_window_size, ) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size from transformer_engine.pytorch.jit import ( set_jit_fusion_options, warmup_jit_bias_dropout_add_all_dtypes, From 2b1b72fa4f912e08c07f82805fdacf044c225bb7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Mar 2025 00:30:07 +0000 Subject: [PATCH 222/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 32 +++++++++++++------------ 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b2a8565ffa..0278d75a13 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2230,9 +2230,7 @@ def backward(ctx, dout): dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): + if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) elif fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -2695,9 +2693,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_forward_kwargs["window_size"] = 
window_size_per_step[i] elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -3853,13 +3849,15 @@ def forward( attention_mask = get_padding_mask( batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv ) - attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = dpa_utils.get_full_mask( - max_seqlen_q, - max_seqlen_kv, - attn_mask_type=attn_mask_type, - attention_mask=attention_mask, - window_size=window_size, - attention_type=self.attention_type, + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( + dpa_utils.get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + ) ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] @@ -4203,7 +4201,9 @@ def forward( assert ( attention_mask is not None ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(attention_mask) + cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices( + attention_mask + ) else: indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q) cu_seqlens_kv = cu_seqlens_q @@ -4215,7 +4215,9 @@ def forward( assert ( attention_mask is not None ), "Please provide attention_mask for padding!" 
- cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(attention_mask[0]) + cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices( + attention_mask[0] + ) cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices( attention_mask[1] ) From a7eeb28bd917a647abf7854fa22239b8ee85c2af Mon Sep 17 00:00:00 2001 From: Li Tao Date: Sat, 15 Mar 2025 08:39:20 +0800 Subject: [PATCH 223/239] [PyTorch] Support TP Overlap in Per-Tensor Current Scaling Recipe (#1554) * support tp-comm-overlap in Current Scaling recipe Signed-off-by: Li Tao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean Signed-off-by: Li Tao * fix test recipe argument to generalize to MXFP8 Signed-off-by: Li Tao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reduce duplicated transpose in certain cases Signed-off-by: Li Tao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use per_tensor_scaling() to judge DS or CS Signed-off-by: Li Tao * minor fixes Signed-off-by: Li Tao * change comment description Signed-off-by: Li Tao * add multi-layer unit test for tp overlap Signed-off-by: Li Tao * support test case that run for several times Signed-off-by: Li Tao * avoid save ub tensor in prepare_for_saving Signed-off-by: Li Tao * fix Signed-off-by: Li Tao * switch to a simple fix Signed-off-by: Li Tao * formatting Signed-off-by: Li Tao * simply test cases; avoid additional clone() Signed-off-by: Li Tao * fall back to get_buffer in layernormmlp Signed-off-by: Li Tao * use 2 layers for fp8 tpoverlap multi-layer test for better tolerance, limit max gpus for test Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Li Tao Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: zhongboz --- .../distributed/run_layer_with_overlap.py | 48 ++++- .../distributed/test_comm_gemm_overlap.py | 166 +++++++++++++++++- .../common/gemm/cublaslt_gemm.cu | 3 +- transformer_engine/common/recipe/__init__.py | 4 + .../pytorch/csrc/extensions/quantizer.cpp | 5 +- .../pytorch/module/layernorm_linear.py | 41 ++++- .../pytorch/module/layernorm_mlp.py | 44 +++-- transformer_engine/pytorch/module/linear.py | 34 ++-- .../pytorch/tensor/float8_tensor.py | 2 +- 9 files changed, 297 insertions(+), 50 deletions(-) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 3526ad812f..526876edf3 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -17,13 +17,25 @@ import torch.distributed as dist import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +class multi_module_model(torch.nn.Module): + def __init__(self, module, num_layers, *args, **kwargs): + super().__init__() + self.num_layers = num_layers + self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + def _te_layer_argtype(name): te_layers = [ te.Linear, @@ -40,10 +52,12 @@ def _te_layer_argtype(name): return layer_map[name.lower()] -def _get_layer_args(config, tp_group, tp_size, reference=False): +def _get_layer_args(config, tp_group, tp_size, num_layers, reference=False): hidden_size = config.num_heads * config.head_dim ffn_hidden_size = 4 * hidden_size qkv_size = 3 * 
hidden_size + if num_layers > 1 and config.layer_type != te.TransformerLayer: + raise ValueError("Stacked layers are only supported for te.TransformerLayer!") input_shape = [config.seq_length, config.batch_size, hidden_size] args = [hidden_size] kwargs = { @@ -106,6 +120,9 @@ def _parse_args(argv=None, namespace=None): description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers." ) parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP) + parser.add_argument( + "--num-layers", type=int, default=1, help="Number of identical layers to stack." + ) parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( @@ -142,6 +159,13 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) + parser.add_argument( + "--quantization", + type=str.lower, + default="none", + choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"], + help="Quantization recipe", + ) parser.add_argument( "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." 
) @@ -341,7 +365,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Initialize the Transformer Engine layer with overlap - args, kwargs, input_shape = _get_layer_args(opts, nccl_world, opts.tp) + args, kwargs, input_shape = _get_layer_args( + opts, nccl_world, opts.tp, num_layers=opts.num_layers + ) # Intialize userbuffers ub_cfgs = None if opts.overlap_rs_dgrad: @@ -359,7 +385,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ) with te.fp8_model_init(enabled=opts.fp8_init): - test_model = opts.layer_type(*args, **kwargs) + test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs) dist_print("Initialized test model...", debug=True) if WORLD_RANK == 0: pprint.pprint(kwargs) @@ -367,9 +393,11 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist.barrier() # Initialize the reference model and copy all parameters - ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, opts.tp, reference=True) + ref_args, ref_kwargs, _ = _get_layer_args( + opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True + ) with te.fp8_model_init(enabled=opts.fp8_init): - ref_model = opts.layer_type(*ref_args, **ref_kwargs) + ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs) dist_print("Initialized reference model...", debug=True) for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): with torch.no_grad(): @@ -379,7 +407,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): # Fp8 recipe setup fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + fp8_recipe = None + if opts.quantization == "fp8_delayed_scaling": + fp8_recipe = DelayedScaling( + fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max" + ) + elif opts.quantization == 
"fp8_current_scaling": + fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format) # Prepare random input tensors test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index eb6b5ca8ed..01400bba6b 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -30,8 +30,11 @@ ] MAX_LAYER_NAME_LENGTH = max([len(layer.__name__) for layer in TE_LAYERS]) +# to avoid numerical tolerance issues of doing comm gemm overlap, limit the number of GPUs used +MAX_GPUS_TO_USE = 4 + TEST_ROOT = Path(__file__).parent.resolve() -NUM_PROCS: int = torch.cuda.device_count() +NUM_PROCS: int = min(torch.cuda.device_count(), MAX_GPUS_TO_USE) LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] if tex.ubuf_built_with_mpi(): LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python3"] @@ -83,7 +86,9 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8): raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): +def _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 +): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -93,6 +98,7 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, f"--num-heads={NUM_HEADS}", f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", + f"--num-layers={num_layers}", ] if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]: test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") @@ -104,6 +110,7 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") + 
test_cmd.append(f"--quantization={quantization}") os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" @@ -195,7 +202,65 @@ def test_bulk_overlaps(comm_type, fp8, connections): _run_gemm_with_overlap(comm_type, True, False, False, fp8) -@pytest.mark.parametrize("fp8", (False, True), ids=[" BF16 ", " FP8 "]) +@pytest.mark.parametrize( + "fp8", + (False,), + ids=[ + " BF16 ", + ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", + [ + (te.Linear.__name__, "row", False), + (te.Linear.__name__, "column", False), + (te.Linear.__name__, "column", True), + (te.LayerNormLinear.__name__, "row", False), + (te.LayerNormLinear.__name__, "column", False), + (te.LayerNormLinear.__name__, "column", True), + ] + + list( + zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + [None] * len(TE_LAYERS[2:]) * 2, + [False, True] * len(TE_LAYERS[2:]), + ) + ), + ids=[ + f" {te.Linear.__name__} - ROW-PARALLEL ", + f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", + f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", + ] + + [ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]), + ) + ], +) +def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): + """ + Test Transformer Engine layers with comm+GEMM overlap. 
+ """ + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None) + + +@pytest.mark.parametrize( + "quantization", + ["fp8_delayed_scaling", "fp8_current_scaling"], + ids=[" DELAYED SCALING ", " CURRENT SCALING "], +) +@pytest.mark.parametrize( + "fp8", + (True,), + ids=[ + " FP8 ", + ], +) @pytest.mark.parametrize( "layer_type,linear_parallel_mode,overlap_rs_dgrad", [ @@ -229,8 +294,99 @@ def test_bulk_overlaps(comm_type, fp8, connections): ) ], ) -def test_layers_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): +def test_layers_with_overlap_fp8( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization +): + """ + Test Transformer Engine layers with comm+GEMM overlap. + """ + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization) + + +@pytest.mark.parametrize( + "fp8", + (False,), + ids=[ + " BF16 ", + ], +) +@pytest.mark.parametrize( + "num_layers", + (2,), + ids=[ + " 2 layers ", + ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", + list( + zip( + [te.TransformerLayer.__name__ for _ in range(2)], + [None] * 2, + [False, True], + ) + ), + ids=[ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [te.TransformerLayer.__name__ for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"], + ) + ], +) +def test_multi_layer_with_overlap_bf16( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers +): + """ + Test Transformer Engine layers with comm+GEMM overlap. 
+ """ + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers + ) + + +@pytest.mark.parametrize( + "quantization", + ["fp8_delayed_scaling", "fp8_current_scaling"], + ids=[" DELAYED SCALING ", " CURRENT SCALING "], +) +@pytest.mark.parametrize( + "fp8", + (True,), + ids=[ + " FP8 ", + ], +) +@pytest.mark.parametrize( + "num_layers", + (2,), + ids=[ + " 2 layers ", + ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", + list( + zip( + [te.TransformerLayer.__name__ for _ in range(2)], + [None] * 2, + [False, True], + ) + ), + ids=[ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [te.TransformerLayer.__name__ for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"], + ) + ], +) +def test_multi_layer_with_overlap_fp8( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers +): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8) + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers + ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 52fa89b914..39b887783b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -242,8 +242,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUDA_VERSION >= 12080 cublasLtMatmulMatrixScale_t scaling_mode; #endif - if ((is_delayed_tensor_scaling(inputA->scaling_mode) && - is_delayed_tensor_scaling(inputB->scaling_mode))) { + if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; void *B_scale_inverse = param.B_scale_inv; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, diff --git 
a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 937383d5ec..50a0a10b5f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -77,6 +77,10 @@ def float8_current_scaling(self): """Whether the given recipe is (per-tensor) current scaling.""" return isinstance(self, Float8CurrentScaling) + def float8_per_tensor_scaling(self): + """Whether the given recipe is per-tensor scaling.""" + return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + @dataclass() class DelayedScaling(Recipe): diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 427bf294d3..3d55fc15d4 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -223,9 +223,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso } const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - //unlike delayed scaling, in current scaling, scale is not known, so scale_inv should be empty buffer - opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); - at::Tensor scale_inv = at::empty(scale_inv_torch_shape, opts); + // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. 
+ at::Tensor scale_inv = at::reciprocal(scale); py::object ret; if (internal) { diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1b62f8d777..9c3c798e68 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -148,12 +148,12 @@ def forward( with_input_all_gather = parallel_mode == "column" and sequence_parallel if fp8: - if ( - any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() ): raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" ) if input_quantizer is None: @@ -177,9 +177,19 @@ def forward( columnwise=backward_needs_input, ) + # Reduce duplicated transpose in `_fix_gathered_fp8_transpose` + if ( + fp8 + and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad + ): + input_quantizer.set_usage(rowwise=True, columnwise=False) + ub_obj_fprop = None ln_out = None - if ub_overlap_ag_fprop: + # For DelayScaling, output of normalization will be in fp8. + # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8. + if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer): ub_obj_fprop = get_ub(ub_name + "_fprop") ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) elif with_quantized_norm: @@ -208,6 +218,14 @@ def forward( ln_out_return = ln_out if return_layernorm_output else None nvtx_range_pop(f"{nvtx_label}.norm") + # For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer. 
+ # So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer. + if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer): + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out_local = ln_out + ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) + input_quantizer.quantize(ln_out_local, out=ln_out) + # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") @@ -371,7 +389,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out.clone() if ub_overlap_ag_fprop else ln_out, # avoid saving a UB buffer mu, rsigma, ) @@ -464,9 +482,10 @@ def backward( ) and (ctx.fp8_recipe is not None) ): - if not ctx.fp8_recipe.delayed(): + if not ctx.fp8_recipe.float8_per_tensor_scaling(): raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" ) saved_tensors = ctx.saved_tensors @@ -553,7 +572,11 @@ def backward( dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) if ctx.grad_output_quantizer is not None: - ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + # Reduce duplicated transpose, which is performed in grad_output.update_usage + if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + else: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1f167b5a7e..633690ba6a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -190,12 +190,12 @@ def forward( inputmat = inp.view((-1, 
in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) - if ( - any([ub_overlap_ag, ub_overlap_rs]) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + if any([ub_overlap_ag, ub_overlap_rs]) and not ( + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() ): raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" ) activation_func = _act_func( @@ -209,7 +209,7 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - # for standard fp8: layernorm output = FP8 + # for fp8 DelayedScaling: layernorm output = FP8 # only output of the linear is returned # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned @@ -237,9 +237,19 @@ def forward( columnwise=backwards_needs_fc1_input, ) + # Reduce duplicated transpose in `_fix_gathered_fp8_transpose` + if ( + fp8 + and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad + ): + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + ub_obj_lnout = None ln_out = None - if ub_overlap_ag: + # For DelayScaling, output of normalization will be in fp8. + # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8. + if ub_overlap_ag and not isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): ub_obj_lnout = get_ub("fc1_fprop") ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True) elif not with_quantized_norm: @@ -263,6 +273,14 @@ def forward( ln_out_return = ln_out if return_layernorm_output else None + # For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer. + # So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer. 
+ if ub_overlap_ag and isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): + ub_obj_lnout = get_ub("fc1_fprop") + ln_out_local = ln_out + ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True) + fc1_input_quantizer.quantize(ln_out_local, out=ln_out) + # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False @@ -589,9 +607,10 @@ def backward( ) and (ctx.fp8_recipe is not None) ): - if not ctx.fp8_recipe.delayed(): + if not ctx.fp8_recipe.float8_per_tensor_scaling(): raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" ) saved_tensors = ctx.saved_tensors @@ -658,10 +677,11 @@ def backward( # Prepare grad output tensor # Note: Cast to expected dtype and perform tensor-parallel communication if ctx.grad_fc2_output_quantizer is not None: - ctx.grad_fc2_output_quantizer.set_usage( - rowwise=True, - columnwise=True, - ) + # Reduce duplicated transpose, which is performed in grad_output.update_usage + if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): + ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False) + else: + ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True) ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 4c87396e3c..77b52dae26 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,12 +129,12 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) - if ( - any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) - and not FP8GlobalStateManager.get_fp8_recipe().delayed() + if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( + 
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() ): raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" ) if input_quantizer is None: @@ -150,10 +150,17 @@ def forward( quantizer=input_quantizer, ) else: - input_quantizer.set_usage( - rowwise=True, - columnwise=backward_needs_input, - ) + if ( + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad + ): + # reduce duplicated transpose in `_fix_gathered_fp8_transpose` + input_quantizer.set_usage(rowwise=True, columnwise=False) + else: + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, + ) if not isinstance(inputmat, QuantizedTensor): inputmat = input_quantizer(inputmat) own_quantized_input = True @@ -364,9 +371,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) and (ctx.fp8_recipe is not None) ): - if not ctx.fp8_recipe.delayed(): + if not ctx.fp8_recipe.float8_per_tensor_scaling(): raise NotImplementedError( - "Comm+GEMM overlap is only supported with FP8 delayed scaling" + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" ) saved_tensors = ctx.saved_tensors @@ -445,7 +453,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Cast to expected dtype and perform tensor-parallel communication if ctx.grad_output_quantizer is not None: - ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + # Reduce duplicated transpose, which is performed in grad_output.update_usage + if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + else: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") 
( grad_output, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5bea0398ab..e45010bb00 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -213,7 +213,7 @@ def __init__( amax_epsilon: float = 0.0, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.scale = torch.empty(1, dtype=torch.float32, device=device) + self.scale = torch.ones(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype self.with_amax_reduction = with_amax_reduction From a6c8455b76bb91f6a06a2a414a8a19e8b07271e0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 18:30:34 -0700 Subject: [PATCH 224/239] fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 14 +- transformer_engine/pytorch/attention.py | 7 +- .../dot_product_attention/inference.py | 805 +++++++++++++++++- .../pytorch/dot_product_attention/utils.py | 511 +++++++---- transformer_engine/pytorch/inference.py | 794 ----------------- 5 files changed, 1106 insertions(+), 1025 deletions(-) delete mode 100644 transformer_engine/pytorch/inference.py diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index d23b9da897..e0b970620e 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -18,11 +18,9 @@ from transformer_engine.pytorch.transformer import ( TransformerLayer, ) -from transformer_engine.pytorch.attention import ( - DotProductAttention, - InferenceParams, - _flash_attn_3_is_installed, -) +from transformer_engine.pytorch.attention import DotProductAttention +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from 
transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.utils import ( get_device_compute_capability, init_method_normal, @@ -411,7 +409,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda config = model_configs_infer[model] num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 # flash-attn v2 requires page_size >= 256 - if backend == "FlashAttention" and not _flash_attn_3_is_installed: + if backend == "FlashAttention" and not fa_utils.v3_is_installed: config_max_seqlen_q = config.max_seqlen_q config_max_seqlen_kv = config.max_seqlen_kv config.max_seqlen_q = 256 @@ -422,7 +420,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda page_size = None total_num_pages = None if is_paged: - page_size = 256 if backend == "FlashAttention" and not _flash_attn_3_is_installed else 1 + page_size = 256 if backend == "FlashAttention" and not fa_utils.v3_is_installed else 1 config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) else: @@ -696,6 +694,6 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda sim.complete_times = sim.serving_times + sim.gen_lens sim.print_summary(logger) - if backend == "FlashAttention" and not _flash_attn_3_is_installed: + if backend == "FlashAttention" and not fa_utils.v3_is_installed: config.max_seqlen_q = config_max_seqlen_q config.max_seqlen_kv = config_max_seqlen_kv diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0278d75a13..165a86b4f7 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -68,7 +68,7 @@ ) from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing -from 
transformer_engine.pytorch.inference import InferenceParams +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, prepare_for_saving, @@ -82,6 +82,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb + # Setup Attention Logging attn_log.setup_logging() @@ -3815,7 +3816,7 @@ def forward( ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" # get q_format and kv_format for training and inference - qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, _ = dpa_utils.get_qkv_format(qkv_layout, inference_params) if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) @@ -3846,7 +3847,7 @@ def forward( ) if "padding" in attn_mask_type and attention_mask is None: - attention_mask = get_padding_mask( + attention_mask = dpa_utils.get_padding_mask( batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv ) attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( diff --git a/transformer_engine/pytorch/dot_product_attention/inference.py b/transformer_engine/pytorch/dot_product_attention/inference.py index 6371bdab57..46e961b381 100644 --- a/transformer_engine/pytorch/dot_product_attention/inference.py +++ b/transformer_engine/pytorch/dot_product_attention/inference.py @@ -2,52 +2,793 @@ # # See LICENSE for license information. 
-""" -Inference classes for attention -""" +"""Inference""" +import logging +from collections import OrderedDict, defaultdict +from typing import Optional, List +from einops import rearrange +import torch -class InferenceParams: # pylint: disable=too-few-public-methods +import transformer_engine_torch as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat + +__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"] + + +class KVCacheManager: + """Base KV cache manager""" + + def __init__(self): + """Initialize cache manager""" + self.cache = {} + self.sequences = OrderedDict() + + def reset(self): + """Reset cache manager state""" + self.sequences = OrderedDict() + + def allocate_memory(self, layer_number: int): + """Allocate memory for the cache""" + self.cache[layer_number] = (None, None) + + def pre_step( + self, + step_dict: OrderedDict, # pylint: disable=unused-argument + ): + """Update tracked sequences and prepare for step()""" + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, # pylint: disable=unused-argument + new_v: torch.Tensor, # pylint: disable=unused-argument + cu_new_seqlens: torch.Tensor, # pylint: disable=unused-argument + cu_cached_seqlens: torch.Tensor, # pylint: disable=unused-argument + qkv_format: str, # pylint: disable=unused-argument + ): + """Copy the new tokens to KV cache""" + return self.cache[layer_number] + + +class InferenceParams: """ - Inference parameters that are passed to the main model in order - to efficiently calculate and store the context during inference. + KV caching for inference. 
The memory allocation of the caches and the copying of new tokens + to the cache take place at the following locations.:: + + class TransformerLayer: + class MultiHeadAttention: + if self.layer_number not in inference_params.cache_manager.cache: + inference_params.allocate_memory(self.layer_number) + class DotProductAttention: + if inference_params is not None: + k_cache, v_cache, new_qkv_format = inference_params.step( + new_k, new_v, qkv_format) + output = attention(new_q, k_cache, v_cache, new_qkv_format) + + allocate_memory() can be called outside the model, independently. step() can take three formats, + qkv_format = {'bshd', 'sbhd', 'thd'}. It converts new_k and new_v to 'bshd' in both + NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on the + backend. If it is unchanged, we would have new_qkv_format = {'bshd', 'sbhd_2bshd', 'thd_2bshd'}. + A standard KV caching workflow for inference is as follows.:: + + model = [TransformerLayer() for _ in range(num_layers)] + # initialize InferenceParams, e.g. with PagedKVCacheManager + inference_params = InferenceParams(..., is_paged=True) + # inference loop + for i in range(num_iters): + # get info for iteration i, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1] + step_dict = OrderedDict(zip(seq_ids, step_lens)) + # update inference_params' state + inference_params.pre_step(step_dict) + # run iteration + output = model( + ..., + attn_mask_type="padding_causal", + cu_seqlens_q=cu_seqlens_new_q, + cu_seqlens_kv=cu_seqlens_new_kv, + inference_params=inference_params, + ) + # get output tokens based on qkv_format + # 'bshd': output = output[:,step_dict.values()-1] + # 'sbhd': output = output[step_dict.values()-1,:] + # 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1 + Parameters ---------- - max_batch_size : int - maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. 
+ max_batch_size: int + Maximum batch size in inference + max_seqlen_kv: int + Maximum sequence length in inference + num_heads_kv: int + Number of attention heads in keys and values + head_dim_k: int + Head size for keys + dtype: torch.dtype + Data type of the KV cache + head_dim_v: int, default = None + Head size for values. If None, initialized as head_dim_k. + is_paged: bool, default = False + Whether the KV cache is paged (True) or non-paged (False) + total_num_pages: int, default = None + Total number of pages in the KV cache. Required for is_paged = True. + page_size: int, default = None + Page size of the KV cache. Required for is_paged = True. + max_ctx_len: int, default = None + Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. + qkv_format: str, default = "bshd" + Format of the incoming query/key/value tensors in current iteration + cache_manager: KVCacheManager, default = None + Custom cache manager, with KVCacheManager as the base class. """ - def __init__(self, max_batch_size, max_sequence_length): - self.max_sequence_length = max_sequence_length + def __init__( + self, + max_batch_size: int, + max_seqlen_kv: int, + num_heads_kv: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: int = None, + is_paged: bool = False, + total_num_pages: int = None, + page_size: int = None, + max_ctx_len: int = None, + qkv_format: str = "bshd", + cache_manager: KVCacheManager = None, + ): self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 - self.key_value_memory_dict = {} + self.max_seqlen_kv = max_seqlen_kv + self.num_heads_kv = num_heads_kv + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + self.is_paged = is_paged + + if not self.is_paged: + cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager + self.cache_manager = cls( + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + 
num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + head_dim_v=self.head_dim_v, + ) + else: + assert page_size is not None, "Paged KV cache requires page_size is not None." + self.page_size = page_size + assert ( + max_seqlen_kv % page_size == 0 + ), "Paged KV cache requires max_seqlen_kv % page_size = 0." + max_pages_per_seq = max_seqlen_kv // page_size + assert ( + total_num_pages == self.max_batch_size * max_pages_per_seq + ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." + self.total_num_pages = total_num_pages + + cls = cache_manager if cache_manager is not None else PagedKVCacheManager + self.cache_manager = cls( + total_num_pages=self.total_num_pages, + page_size=self.page_size, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + head_dim_v=self.head_dim_v, + ) + + if qkv_format == "thd": + assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" 
+ self.max_ctx_len = max_ctx_len + + self.cache_qkv_format = "bshd" + self.input_qkv_format = qkv_format + if self.input_qkv_format == self.cache_qkv_format: + self.output_qkv_format = self.cache_qkv_format + else: + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + + self.sequences_pre_step = OrderedDict() + self.sequences = OrderedDict() + self.batch_size = 0 + + self.cu_seqlens_q = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_kv = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def reset(self): + """Reset InferenceParams state""" + self.sequences = OrderedDict() + self.cache_manager.reset() + + def __repr__(self) -> str: + if self.is_paged: + return ( + f"dtype={self.dtype}, " + f"is_paged={self.is_paged}, " + f"total_pages={self.total_num_pages}, " + f"page_size={self.page_size}, " + f"num_heads={self.num_heads_kv}, " + f"head_dim_k={self.head_dim_k}, " + f"head_dim_v={self.head_dim_v}" + ) + return ( + f"dtype={self.dtype}, " + f"is_paged={self.is_paged}, " + f"max_batch_size={self.max_batch_size}, " + f"max_seqlen={self.max_seqlen_kv}, " + f"num_heads={self.num_heads_kv}, " + f"head_dim_k={self.head_dim_k}, " + f"head_dim_v={self.head_dim_v}" + ) - def swap_key_value_dict(self, batch_indices): + def allocate_memory(self, layer_number: int): """ - Reorders the KV cache using the specified batch indices. + Allocate memory for the cache. 
For layer layer_number, + - NonPagedKVCacheManager: + - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] + - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] + - PagedKVCacheManager: + - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] + - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] + """ + self.cache_manager.allocate_memory(layer_number) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + self.batch_size = len(step_dict) + + self.sequences = self.cache_manager.pre_step(step_dict) + # track the pre-step seqlens for the next layer in the model + self.sequences_pre_step = OrderedDict() + for k, v in self.sequences.items(): + self.sequences_pre_step[k] = v - step_dict[k] + + seqlens_q = list(step_dict.values()) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) + self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu")) + + seqlens_kv = list(self.sequences.values()) + cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)] + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + self.max_batch_size - self.batch_size + ) + self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu")) + + def get_seqlens_pre_step(self): + """Get cached sequence lengths before the stepping""" + return torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + + def convert_paged_to_nonpaged(self, layer_number: int): + """ + Convert k_cache and v_cache from paged to non-paged format. Parameters ---------- - batch_indices : List[int] - Sequence of indices to reorder along the batch dimensions of - the KV cache. Must have a length equal to the batch size. 
+ layer_number: int + Layer number of attention in the model + + Returns + ------- + k_cache: torch.Tensor + Non-paged key cache tensor + v_cache: torch.Tensor + Non-paged value cache tensor """ - if len(self.key_value_memory_dict) == 0: - raise ValueError("should not swap when dict in empty") + k_cache, v_cache = self.cache_manager.cache[layer_number] + page_table = self.cache_manager.page_table + batch_size = page_table.shape[0] + new_k_cache = rearrange( + k_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, + ) + new_v_cache = rearrange( + v_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, + ) - for layer_number, inference_memory in self.key_value_memory_dict.items(): - inference_key_memory, inference_value_memory = inference_memory - assert ( - len(batch_indices) == inference_key_memory.shape[1] - ) # make sure batch size is the same - new_inference_key_memory = inference_key_memory[:, batch_indices] - new_inference_value_memory = inference_value_memory[:, batch_indices] - self.key_value_memory_dict[layer_number] = ( - new_inference_key_memory, - new_inference_value_memory, + new_k_cache = new_k_cache[: self.batch_size].contiguous() + new_v_cache = new_v_cache[: self.batch_size].contiguous() + + return new_k_cache, new_v_cache + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + qkv_format: str, + ): + """ + Copy new KV tokens to the cache. 
+ + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + qkv_format: str + Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + cu_seqlens_q: torch.Tensor + Updated cumulative sequence lengths for query, [batch_size + 1] + cu_seqlens_kv: torch.Tensor + Updated cumulative sequence lengths for key and value, [batch_size + 1] + max_seqlen_q: int + Update maximum sequence length for query + max_seqlen_kv: int + Update maximum sequence length for key and value + qkv_format: str + Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step() + """ + self.input_qkv_format = qkv_format + if self.input_qkv_format == self.cache_qkv_format: + self.output_qkv_format = self.cache_qkv_format + else: + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + + k_cache, v_cache = self.cache_manager.step( + layer_number, + new_k, + new_v, + self.cu_seqlens_q, + self.cu_seqlens_kv, + qkv_format, + ) + + return ( + k_cache, + v_cache, + self.cu_seqlens_q, + self.cu_seqlens_kv, + self.max_seqlen_kv, + self.output_qkv_format, + ) + + +class NonPagedKVCacheManager(KVCacheManager): + """Non-paged KV cache manager""" + + def __init__( + self, + max_batch_size: int, + max_seqlen: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + ): + super().__init__() + """Initialize cache manager""" + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not 
None else head_dim_k + + # track sequences in the cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # track sequence indices in the batch in order to re-index k_cache and v_cache + self.batch_indices = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + # after re-indexing, batch indices are always [0, ..., b-1] + self.batch_indices_post_step = torch.range( + 0, + self.max_batch_size - 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def allocate_memory(self, layer_number): + """Allocate memory for the cache""" + k_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + # Track unfinished sequences' indices in the batch, e.g. 
+ # at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished + # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that + # they are contiguous and match the indexing in q + prev_batch_size = len(self.sequences) + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] + finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] + self.batch_indices.copy_( + torch.Tensor( + ( + unfinished_indices + + finished_indices + + list(range(prev_batch_size, self.max_batch_size)) + ) + ).to(dtype=torch.int32, device="cpu") + ) + + # Advance unfinished sequences + for i in unfinished_seqs: + self.sequences[i] += 1 + + # Remove finished sequences + for i in finished_seqs: + self.sequences.pop(i) + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for i in new_seqs: + self.sequences[i] = step_dict[i] + + return self.sequences + + def step( + self, + layer_number, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, + qkv_format: str, + ): + """ + Copy the new tokens to the non-paged KV cache. 
+ + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + qkv_format: str + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + """ + k_cache, v_cache = self.cache[layer_number] + + batch_size = self.max_batch_size + ctx_len = 1 + if qkv_format == "bshd": + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] + if qkv_format == "sbhd": + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + + tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + self.batch_indices, + cu_new_seqlens, + cu_cached_seqlens, + QKVFormat[qkv_format], + batch_size, + ctx_len, + self.max_seqlen, + 1, + True, + ) + + k_cache = k_cache[:batch_size] + v_cache = v_cache[:batch_size] + + return k_cache, v_cache + + +class Page: + """A single page""" + + def __init__(self, page_id: int): + """Initialize a page""" + self.page_id = page_id + self.allocated = 0 + + def allocate_page(self): + """Allocate a page""" + self.allocated = True + + def deallocate_page(self): + """Deallocate a page""" + self.allocated = False + + +class PagedKVCacheManager(KVCacheManager): + """Paged KV cache manager""" + + def __init__( + self, + total_num_pages: int, + page_size: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + max_batch_size: int, + max_seqlen: int, + head_dim_v: Optional[int] = None, + ): + 
super().__init__() + """Initialize cache manager""" + self.total_num_pages = total_num_pages + self.page_size = page_size + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.max_pages_per_seq = max_seqlen // self.page_size + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + + # track sequences in the cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # available pages, [Page(),...] + self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + # allocated pages, {seq_id: [page_id,...]} + self.allocated_pages = defaultdict(list) + # page table, [batch_size, max_pages_per_seq] + self.page_table = torch.zeros( + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + ) + + def reset(self): + """Reset cache manager state""" + self.sequences = OrderedDict() + self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + self.allocated_pages = defaultdict(list) + self.page_table.fill_(0) + + def allocate_memory(self, layer_number): + """Allocate memory for the cache""" + k_cache = torch.zeros( + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.zeros( + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + def print_cache(self): + """Print KV cache status""" + used_pages = [self.get_page_count(seq) for seq in self.sequences] + logger = logging.getLogger("PagedKVCacheManager") + logger.debug("Cache status:") + logger.debug( + " total pages: %s (used %s, free %s)", + self.total_num_pages, + sum(used_pages), + 
len(self.free_pages), + ) + logger.debug(" total sequences: %s", self.get_sequence_count()) + for i, seq in enumerate(self.sequences): + logger.debug( + " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", + i, + seq, + self.get_sequence_lengths()[i], + self.get_page_count(seq), + self.get_page_list(seq), ) + + def get_sequence_count(self): + """Get the total number of sequences in the KV cache""" + return len(self.sequences) + + def get_sequence_lengths(self): + """Get the list of sequence lengths in the KV cache""" + return list(self.sequences.values()) + + def has_free_page(self) -> bool: + """Whether the page pool has any free pages left""" + return len(self.free_pages) > 0 + + def get_page_count(self, seq: int): + """Get the number of pages allocated to a sequence""" + return len(self.allocated_pages[seq]) + + def get_page_list(self, seq: int): + """Get the list of pages allocated to a sequence""" + return [x.page_id for x in self.allocated_pages[seq]] + + def get_page_table(self, sequences: List[int]): + """Get the page table, in shape [batch_size, max_pages_per_seq]""" + page_table = torch.Tensor( + [ + self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) + for seq in sequences + ] + ).to(dtype=torch.int32, device="cpu") + self.page_table[: self.get_sequence_count()].copy_(page_table) + return self.page_table + + def allocate_page(self, seq: int): + """Allocate a new page to a sequence""" + if not self.has_free_page(): + raise RuntimeError("KV cache is full!") + page = self.free_pages.pop(0) + page.allocate_page() + self.allocated_pages[seq].append(page) + + def allocate_sequence(self, seq: int, context_len: int): + """Add a new sequence to the cache""" + num_pages = context_len // self.page_size + if context_len % self.page_size > 0: + num_pages = num_pages + 1 + for _ in range(num_pages): + self.allocate_page(seq) + + def deallocate_sequence(self, seq: int): + """Deallocate all the pages for a 
sequence""" + for page in self.allocated_pages[seq]: + page.deallocate_page() + if not page.allocated: + self.free_pages.append(page) + self.allocated_pages.pop(seq) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + # Remove finished sequences and advance unfinished sequences + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + for seq in finished_seqs: + self.sequences.pop(seq) + self.deallocate_sequence(seq) + for seq in unfinished_seqs: + if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: + self.allocate_page(seq) + self.sequences[seq] += 1 + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for seq in new_seqs: + self.sequences[seq] = step_dict[seq] + self.allocate_sequence(seq, step_dict[seq]) + + # Get page table + self.page_table = self.get_page_table(list(self.sequences.keys())) + + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, + qkv_format: str, + ): + """ + Copy the new tokens to the paged KV cache. 
+ + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + qkv_format: str + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + """ + k_cache, v_cache = self.cache[layer_number] + + batch_size = self.max_batch_size + ctx_len = 1 + if qkv_format == "bshd": + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] + if qkv_format == "sbhd": + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + + tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + self.page_table, + cu_new_seqlens, + cu_cached_seqlens, + QKVFormat[qkv_format], + batch_size, + ctx_len, + self.max_seqlen, + self.max_pages_per_seq, + False, + ) + + return k_cache, v_cache diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py index a4424d9d38..a4f05334e1 100644 --- a/transformer_engine/pytorch/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -34,6 +34,7 @@ META_O_CP, META_DQKV_CP, ) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -102,6 +103,7 @@ class FlashAttentionUtils: 
v2_3_plus = False v2_4_plus = False v2_4_1_plus = False + v2_5_plus = False v2_5_7_plus = False v2_6_0_plus = False v2_7_0_plus = False @@ -110,13 +112,14 @@ class FlashAttentionUtils: fa3_version = PkgVersion("0") v3_0_0_beta = False use_v3 = False - # TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved - # https://github.com/Dao-AILab/flash-attention/issues/1452 + # FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper + # Please follow these instructions to install FA3 v3_installation_steps = """\ - (1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" - (2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` - (3) mkdir -p $python_path/flashattn_hopper - (4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" + (1) git clone https://github.com/Dao-AILab/flash-attention.git + (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install + (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` + (4) mkdir -p $python_path/flash_attn_3 + (5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" @staticmethod def set_flash_attention_version(): @@ -129,13 +132,12 @@ def set_flash_attention_version(): FlashAttentionUtils.v2_3_plus = FlashAttentionUtils.version >= PkgVersion("2.3") FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4") FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1") + FlashAttentionUtils.v2_5_plus = FlashAttentionUtils.version >= PkgVersion("2.5.0") FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7") FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0") 
FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0") - # Detect flash-attn v3 in the environment - # This section will be removed when FA3 is released as a regular FA package, - # i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0 + # Detect flash-attn v3 in the environment (Hopper only) @staticmethod def set_flash_attention_3_params(): """ @@ -145,7 +147,6 @@ def set_flash_attention_3_params(): FlashAttentionUtils.v3_0_0_beta = ( PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0") ) - FlashAttentionUtils.use_v3 = True @dataclass(eq=True) @@ -203,6 +204,8 @@ class AttentionParams: Whether `DotProductAttention` is in an `fp8_autocast` region. fp8_meta: Optional[Dict[str Any]], default = `None` The FP8 metadata tensor of `DotProductAttention`. + inference_params: Optional[InferenceParams], default = `None` + Inference-related parameters. See InferenceParams for details. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -228,6 +231,7 @@ class AttentionParams: is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None + inference_params: Optional[InferenceParams] = None def __eq__(self, other): """ @@ -298,6 +302,7 @@ def get_attention_backend( is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta + inference_params = attention_params.inference_params # Run config logger = logging.getLogger("DotProductAttention") @@ -334,13 +339,19 @@ def get_attention_backend( # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. 
+ qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) # Filter: Environment variables use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_flash_attention_2 = use_flash_attention + use_flash_attention_3 = use_flash_attention + flash_attention_backend = None use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if not use_flash_attention and FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") + if not use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0") + if not use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0") if not use_fused_attention: logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") if not use_unfused_attention: @@ -348,60 +359,124 @@ def get_attention_backend( # Filter: Compute capability if device_compute_capability < (8, 0): - if use_flash_attention and FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention as it requires compute capability sm80+") - use_flash_attention = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 for compute capability < sm80") + use_flash_attention_2 = False if use_fused_attention: - logger.debug("Disabling FusedAttention as it requires compute capability sm80+") + logger.debug("Disabling FusedAttention for compute capability < sm80") use_fused_attention = False - if device_compute_capability < (9, 0): - if use_flash_attention and FlashAttentionUtils.v3_is_installed: - logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") - FlashAttentionUtils.use_v3 = False + if device_compute_capability != (9, 0): + if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling 
FlashAttention 3 for compute capability != sm90") + use_flash_attention_3 = False # Filter: Data type - if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ + if qkv_dtype not in [torch.bfloat16, torch.float16]: + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for unsupported qkv_dtype = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. ", + qkv_dtype, + ) + use_flash_attention_2 = False + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ torch.Tensor, Float8Tensor, ]: - if use_flash_attention and FlashAttentionUtils.is_installed: + if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " + "qkv_type = {torch.Tensor, Float8Tensor}. ", qkv_dtype, + qkv_type, ) - use_flash_attention = False + use_flash_attention_3 = False if use_fused_attention: logger.debug( - "Disabling FusedAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " + "qkv_type = {torch.Tensor, Float8Tensor}. 
", qkv_dtype, + qkv_type, ) use_fused_attention = False # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and not FlashAttentionUtils.use_v3: - if FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") - use_flash_attention = False - if use_flash_attention and FlashAttentionUtils.use_v3 and is_training: - logger.debug( - "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" - ) - use_flash_attention = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 for FP8 attention") + use_flash_attention_2 = False + if use_flash_attention_3 and is_training: + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for FP8 training") + use_flash_attention_3 = False if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False + # Filter: KV cache + # backend | precision | KV cache | architecture | qkv_format | page_size + # --------------------------------------------------------------------------------------- + # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 + # | FP8 | non-paged/paged | sm90 | thd | >= 1 + # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 + if inference_params is not None: + if context_parallel: + logger.debug("Disabling all backends for KV caching with context parallelism") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8_meta["recipe"].fp8_mha: + logger.debug("Disabling all backends for KV caching with FP8 MHA") + 
use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if use_flash_attention_3 and q_format != "thd": + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") + use_flash_attention_3 = False + if use_fused_attention: + logger.debug("Disabling FusedAttention for FP8 KV caching") + use_fused_attention = False + else: + if q_format == "thd" and pad_between_seqs: + logger.debug("Disabling all backends for pad_between_seqs = True and KV caching") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if inference_params.is_paged: + if use_flash_attention_2 and inference_params.page_size < 256: + if FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 for page size < 256") + use_flash_attention_2 = False + if use_flash_attention_2: + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.5") + elif not FlashAttentionUtils.v2_5_plus: + logger.debug( + "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" + ) + use_flash_attention_2 = False + # Filter: Head dimension - if use_flash_attention and head_dim_qk != head_dim_v: - if FlashAttentionUtils.is_installed: + if head_dim_qk != head_dim_v: + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): logger.debug("Disabling FlashAttention as it does not support MLA.") use_flash_attention = False - if use_flash_attention and ( + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False + if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 or ( @@ -411,7 +486,7 @@ def 
get_attention_backend( ): if FlashAttentionUtils.is_installed: logger.debug( - "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " "head_dim_qk <= 256 (>192 requires sm80/90/100+). " "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", @@ -419,23 +494,21 @@ def get_attention_backend( head_dim_v, ".".join([str(i) for i in device_compute_capability]), ) - use_flash_attention = False - qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") - if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": - logger.debug( - "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", - qkv_layout, - ) - use_fused_attention = False + use_flash_attention_2 = False + if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for head_dim > 128") + use_flash_attention_3 = False # Filter: QKV layout - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) if qkv_format == "thd": if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") use_unfused_attention = False - if use_flash_attention and pad_between_seqs: - if FlashAttentionUtils.is_installed: + if pad_between_seqs: + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): logger.debug( "Disabling FlashAttention for qkv_format = thd when there is " "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" @@ -443,9 +516,9 @@ def get_attention_backend( use_flash_attention = False # Filter: Dropout - if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3: + if attention_dropout != 0.0 and use_flash_attention_3: logger.debug("Disabling FlashAttention 3 for dropout") - FlashAttentionUtils.use_v3 = False + use_flash_attention_3 = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -464,42 +537,38 @@ def get_attention_backend( "Disabling UnfusedDotProductAttention as it does not support context parallelism" ) use_unfused_attention = False - if context_parallel and use_flash_attention: - if fp8 and fp8_meta["recipe"].fp8_dpa: - if FlashAttentionUtils.is_installed: + if context_parallel and (use_flash_attention_2 or use_flash_attention_3): + if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed: + if fp8 and fp8_meta["recipe"].fp8_dpa: logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" ) - use_flash_attention = False - if "bottom_right" in attn_mask_type: - if FlashAttentionUtils.is_installed: + use_flash_attention = False + if "bottom_right" in attn_mask_type: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal_bottom_right masking" ) - use_flash_attention = False - elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - if FlashAttentionUtils.is_installed: + use_flash_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " causal masking for cross-attention" ) - use_flash_attention = False - elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - if FlashAttentionUtils.is_installed: + use_flash_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: logger.debug( 
"Disabling FlashAttention as it does not support context parallelism with bias" " type of %s", core_attention_bias_type, ) - use_flash_attention = False - elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - if FlashAttentionUtils.is_installed: + use_flash_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": logger.debug( "Disabling FlashAttention as it does not support context parallelism with" " attention bias for THD format" ) - use_flash_attention = False + use_flash_attention = False if context_parallel and use_fused_attention: if "bottom_right" in attn_mask_type: @@ -552,61 +621,25 @@ def get_attention_backend( # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": - if use_flash_attention and FlashAttentionUtils.is_installed: + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): logger.debug("Disabling FlashAttention for arbitrary mask") use_flash_attention = False if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False if ( - use_flash_attention - and FlashAttentionUtils.use_v3 + (use_flash_attention_2 or use_flash_attention_3) and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): logger.warning( - "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " + "Disabling FlashAttention as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1 (our minimum supported version). 
See " "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" ) - FlashAttentionUtils.use_v3 = False - if ( - use_flash_attention - and attn_mask_type in ["causal", "padding_causal"] - and max_seqlen_q != max_seqlen_kv - ): - if FlashAttentionUtils.v2_1_plus: - logger.warning( - "Disabling FlashAttention as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.max_version = PkgVersion("2.1") - if ( - use_flash_attention - and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"] - and max_seqlen_q != max_seqlen_kv - ): - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.version_required = PkgVersion("2.1") - elif not FlashAttentionUtils.v2_1_plus and not FlashAttentionUtils.use_v3: - logger.warning( - "Disabling FlashAttention as it only supports top-left-diagonal " - "causal mask before flash-attn 2.1. 
See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False - if ( - use_flash_attention - and FlashAttentionUtils.use_v3 - and fp8 - and fp8_meta["recipe"].fp8_dpa - and "padding" in attn_mask_type - ): - logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") - FlashAttentionUtils.use_v3 = False + use_flash_attention = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -637,19 +670,14 @@ def get_attention_backend( "with s_q > s_kv for cross-attention" ) use_fused_attention = False - if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if FlashAttentionUtils.use_v3: - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) - FlashAttentionUtils.use_v3 = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if not FlashAttentionUtils.is_installed: FlashAttentionUtils.version_required = PkgVersion("2.3") elif not FlashAttentionUtils.v2_3_plus: logger.debug( "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" ) - use_flash_attention = False + use_flash_attention_2 = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -660,21 +688,25 @@ def get_attention_backend( # | | bottom_right (converts to a 'post_scale_bias' bias) # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias - if use_flash_attention and core_attention_bias_type == "alibi": - if FlashAttentionUtils.use_v3: - logger.debug("Disabling FlashAttention 3 for ALiBi") - FlashAttentionUtils.use_v3 = False - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.version_required = PkgVersion("2.4") - elif not FlashAttentionUtils.v2_4_plus: - logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") - use_flash_attention = False + if 
core_attention_bias_type == "alibi": + if use_flash_attention_3: + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for ALiBi") + use_flash_attention_3 = False + if use_flash_attention_2: + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.4") + elif not FlashAttentionUtils.v2_4_plus: + logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") + use_flash_attention_2 = False - if use_flash_attention and ( + if ( core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias_shape is not None ): - if FlashAttentionUtils.is_installed: + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): logger.debug("Disabling FlashAttention for pre/post_scale_bias") use_flash_attention = False @@ -779,16 +811,16 @@ def get_attention_backend( # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes - if use_flash_attention and deterministic: + if use_flash_attention_2 and deterministic: if not FlashAttentionUtils.is_installed: FlashAttentionUtils.version_required = PkgVersion("2.4.1") - elif not FlashAttentionUtils.v2_4_1_plus and not FlashAttentionUtils.use_v3: + elif not FlashAttentionUtils.v2_4_1_plus: logger.warning( "Disabling FlashAttention as version <2.4.1 does not support deterministic " "execution. To use FlashAttention with deterministic behavior, " "please install flash-attn >= 2.4.1." 
) - use_flash_attention = False + use_flash_attention_2 = False if use_fused_attention and deterministic: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: logger.debug("Disabling FusedAttention for determinism reasons") @@ -805,29 +837,46 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for determinism reasons") use_fused_attention = False - # All available backends - available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + # use_flash_attention may have been set above + use_flash_attention_2 = use_flash_attention and use_flash_attention_2 + use_flash_attention_3 = use_flash_attention and use_flash_attention_3 # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. # When `FusedAttention` does not support the provided attention params, and `FlashAttention` # does, we recommend users to install flash-attn if not installed already. - if not use_fused_attention and use_flash_attention and not FlashAttentionUtils.is_installed: - logger.warning( - "flash-attn may provide important feature support or performance improvement." - " Please install flash-attn %s.", - _get_supported_versions( - FlashAttentionUtils.version_required, - FlashAttentionUtils.max_version, - ), - ) - if use_flash_attention and not FlashAttentionUtils.is_installed: - use_flash_attention = False - available_backends[0] = False + if not use_fused_attention: + if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed: + logger.warning( + "flash-attn v3 may provide important feature support or performance improvement." + " Please install flash-attn v3 by \n%s", + FlashAttentionUtils.v3_installation_steps, + ) + elif use_flash_attention_2 and not FlashAttentionUtils.is_installed: + logger.warning( + "flash-attn may provide important feature support or performance improvement." 
+ " Please install flash-attn %s by pip3 install flash-attn==.", + _get_supported_versions( + FlashAttentionUtils.version_required, + FlashAttentionUtils.max_version, + ), + ) + # All available backends + if use_flash_attention_2 and not FlashAttentionUtils.is_installed: + use_flash_attention_2 = False + if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed: + use_flash_attention_3 = False + use_flash_attention = use_flash_attention_2 or use_flash_attention_3 + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + if use_flash_attention_2: + flash_attention_backend = FlashAttentionUtils.version + if use_flash_attention_3: + flash_attention_backend = FlashAttentionUtils.fa3_version logger.debug( - "Available backends = {FlashAttention=%s, FusedAttention=%s%s," + "Available backends = {FlashAttention=%s%s, FusedAttention=%s%s," " UnfusedDotProductAttention=%s}", bool(available_backends[0]), + (f" ({str(flash_attention_backend)})" if flash_attention_backend is not None else ""), bool(available_backends[1]), ( f" (sub-backend {int(fused_attention_backend)})" @@ -838,26 +887,10 @@ def get_attention_backend( ) # Select FusedAttention for performance - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - ): - if device_compute_capability >= (9, 0): - logger.debug( - "Disabling FlashAttention to give FusedAttention preference on Hopper+ " - "for performance reasons" - ) - use_flash_attention = False - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["FP8"] - and FlashAttentionUtils.use_v3 - ): + if use_flash_attention and use_fused_attention and device_compute_capability >= (9, 0): logger.debug( - "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " - "in FP8 execution" + "Disabling FlashAttention to give FusedAttention preference on Hopper+ " + "for performance 
reasons" ) use_flash_attention = False @@ -869,22 +902,16 @@ def get_attention_backend( use_unfused_attention = False selected_backend = "NoBackend" if use_flash_attention: - selected_backend = "FlashAttention" + selected_backend = f"FlashAttention ({str(flash_attention_backend)})" elif use_fused_attention: selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" elif use_unfused_attention: selected_backend = "UnfusedDotProductAttention" logger.debug("Selected backend = %s", selected_backend) - """global _attention_backends - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False""" - return ( use_flash_attention, + flash_attention_backend, use_fused_attention, fused_attention_backend, use_unfused_attention, @@ -892,6 +919,49 @@ def get_attention_backend( ) +@torch.no_grad() +def get_padding_mask( + batch_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_q: int, + max_seqlen_kv: int, +): + """Convert cu_seqlens to attention_mask""" + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) + attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) + for i in range(batch_size): + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor([False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i])) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask_kv = torch.cat( + [ + attention_mask_kv, + torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i])) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + 
dim=0, + ) + attention_mask = ( + attention_mask_q.to(device="cuda"), + attention_mask_kv.to(device="cuda"), + ) + return attention_mask + + @torch.no_grad() def get_full_mask( max_seqlen_q: int, @@ -1400,11 +1470,46 @@ def backward(ctx, grad_output): return None, None, _pack_tensor(indices, grad_output) +def get_qkv_format( + qkv_layout: str = "bshd_bshd_bshd", + inference_params: InferenceParams = None, +) -> str: + """Get qkv format. + + Parameters + ---------- + qkv_layout: str + Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details. + inference_params: InferenceParams, default = `None` + InferenceParams related to KV caching. + + Returns + ---------- + qkv_format: str, default = `sbhd` + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. + q_format: str + Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}. + kv_format: str + Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}. + """ + splited = qkv_layout.replace("paged_kv_", "").split("_") + if inference_params is not None: + q_format = "".join([i for i in splited[0] if i.isalpha()]) + kv_format = "".join([i for i in splited[1] if i.isalpha()]) + qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format + else: + qkv_format = "".join([i for i in splited[0] if i.isalpha()]) + q_format = qkv_format + kv_format = qkv_format + return qkv_format, q_format, kv_format + + def get_qkv_layout( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_format: str = "sbhd", + inference_params: InferenceParams = None, ) -> str: """Get qkv layout. @@ -1421,20 +1526,33 @@ def get_qkv_layout( the sequence length dimension, `b` batch size, `h` the number of attention heads, `d` head size, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. + inference_params: InferenceParams, default = `None` + InferenceParams related to KV caching. Returns ---------- qkv_layout: str - Memory layout of `q`, `k` and `v`. 
Each `qkv_format` can be mapped to one of five - memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk - of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means - `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v` - are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and - `v = kv[:,:,:,1,:]`. + Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and + `kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that + paged KV caching is in play. A few examples of the layouts are as follows. + + (1) `sb3hd` means `q`, `k`, `v` are created as one chunk of memory and that they are + interleaved in the `2`nd dimension. (2) `sbhd_sbh2d` means `q` and `kv` are created in + two chunks and that `q` itself is contiguous and `k`, `v` are interleaved with each other + in the `3`rd dimension, `k = kv[:,:,:,0,:]` and `v = kv[:,:,:,1,:]`. `q_format` and + `kv_format` in this case are still both `sbhd`. (3) `paged_kv_thd_bshd_bshd` means `q` is + created in `thd` and `k`, `v` are in `bshd`. This is likely due to the cache format in + paged KV caching. + Mapping: - `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} - `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} + `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`, `paged_kv_sbhd_sbhd_sbhd`} + `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`, `paged_kv_bshd_bshd_bshd`} `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} + `sbhd_2bshd`: {`sbhd_bshd_bshd`, `paged_kv_sbhd_bshd_bshd`} + `bshd_2sbhd`: {`bshd_sbhd_sbhd`, `paged_kv_bshd_sbhd_sbhd`} + `thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`} + `thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`} + q: torch.Tensor Query tensor.
It may be different from input `q` as we try to fit tensors to a supported layout. @@ -1444,10 +1562,21 @@ def get_qkv_layout( v: torch.Tensor Value tensor. It may be different from input `v` as we try to fit tensors to a supported layout. + q_format: str + Format of the query tensor, {`bshd`, `sbhd`, `thd`}. + kv_format: str + Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. """ check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" + if "_2" in qkv_format: + q_format, kv_format = qkv_format.split("_2") + is_same_q_kv_format = False + else: + q_format = qkv_format + kv_format = qkv_format + is_same_q_kv_format = True def run_iteratively(q, k, v): # check data pointers @@ -1534,7 +1663,10 @@ def run_iteratively(q, k, v): # three chunks of memory, q, k and v, which may be disjoint or consecutive, and # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or # check_ptrs_qk=True or check_ptrs_kv=True - qkv_layout = "_".join(list([qkv_format]) * 3) + if is_same_q_kv_format: + qkv_layout = "_".join(list([qkv_format]) * 3) + else: + qkv_layout = q_format + "_" + kv_format + "_" + kv_format else: qkv_layout = "not_supported" @@ -1548,7 +1680,10 @@ def run_iteratively(q, k, v): if qkv_layout == "not_supported": raise RuntimeError("The provided qkv memory layout is not supported!") - return qkv_layout, q, k, v + if inference_params is not None and inference_params.is_paged: + qkv_layout = "paged_kv_" + qkv_layout + + return qkv_layout, q, k, v, q_format, kv_format def check_set_window_size( diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py deleted file mode 100644 index 46e961b381..0000000000 --- a/transformer_engine/pytorch/inference.py +++ /dev/null @@ -1,794 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# -# See LICENSE for license information. - -"""Inference""" -import logging -from collections import OrderedDict, defaultdict -from typing import Optional, List -from einops import rearrange - -import torch - -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat - -__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"] - - -class KVCacheManager: - """Base KV cache manager""" - - def __init__(self): - """Initialize cache manager""" - self.cache = {} - self.sequences = OrderedDict() - - def reset(self): - """Reset cache manager state""" - self.sequences = OrderedDict() - - def allocate_memory(self, layer_number: int): - """Allocate memory for the cache""" - self.cache[layer_number] = (None, None) - - def pre_step( - self, - step_dict: OrderedDict, # pylint: disable=unused-argument - ): - """Update tracked sequences and prepare for step()""" - return self.sequences - - def step( - self, - layer_number: int, - new_k: torch.Tensor, # pylint: disable=unused-argument - new_v: torch.Tensor, # pylint: disable=unused-argument - cu_new_seqlens: torch.Tensor, # pylint: disable=unused-argument - cu_cached_seqlens: torch.Tensor, # pylint: disable=unused-argument - qkv_format: str, # pylint: disable=unused-argument - ): - """Copy the new tokens to KV cache""" - return self.cache[layer_number] - - -class InferenceParams: - """ - KV caching for inference. 
The memory allocation of the caches and the copying of new tokens - to the cache take place at the following locations.:: - - class TransformerLayer: - class MultiHeadAttention: - if self.layer_number not in inference_params.cache_manager.cache: - inference_params.allocate_memory(self.layer_number) - class DotProductAttention: - if inference_params is not None: - k_cache, v_cache, new_qkv_format = inference_params.step( - new_k, new_v, qkv_format) - output = attention(new_q, k_cache, v_cache, new_qkv_format) - - allocate_memory() can be called outside the model, independently. step() can take three formats, - qkv_format = {'bshd', 'sbhd', 'thd'}. It converts new_k and new_v to 'bshd' in both - NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on the - backend. If it is unchanged, we would have new_qkv_format = {'bshd', 'sbhd_2bshd', 'thd_2bshd'}. - A standard KV caching workflow for inference is as follows.:: - - model = [TransformerLayer() for _ in range(num_layers)] - # initialize InferenceParams, e.g. with PagedKVCacheManager - inference_params = InferenceParams(..., is_paged=True) - # inference loop - for i in range(num_iters): - # get info for iteration i, e.g. 
seq_ids = [0, 2, 3], step_lens = [10, 1, 1] - step_dict = OrderedDict(zip(seq_ids, step_lens)) - # update inference_params' state - inference_params.pre_step(step_dict) - # run iteration - output = model( - ..., - attn_mask_type="padding_causal", - cu_seqlens_q=cu_seqlens_new_q, - cu_seqlens_kv=cu_seqlens_new_kv, - inference_params=inference_params, - ) - # get output tokens based on qkv_format - # 'bshd': output = output[:,step_dict.values()-1] - # 'sbhd': output = output[step_dict.values()-1,:] - # 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1 - - - Parameters - ---------- - max_batch_size: int - Maximum batch size in inference - max_seqlen_kv: int - Maximum sequence length in inference - num_heads_kv: int - Number of attention heads in keys and values - head_dim_k: int - Head size for keys - dtype: torch.dtype - Data type of the KV cache - head_dim_v: int, default = None - Head size for values. If None, initialized as head_dim_k. - is_paged: bool, default = False - Whether the KV cache is paged (True) or non-paged (False) - total_num_pages: int, default = None - Total number of pages in the KV cache. Required for is_paged = True. - page_size: int, default = None - Page size of the KV cache. Required for is_paged = True. - max_ctx_len: int, default = None - Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. - qkv_format: str, default = "bshd" - Format of the incoming query/key/value tensors in current iteration - cache_manager: KVCacheManager, default = None - Custom cache manager, with KVCacheManager as the base class. 
- """ - - def __init__( - self, - max_batch_size: int, - max_seqlen_kv: int, - num_heads_kv: int, - head_dim_k: int, - dtype: torch.dtype, - head_dim_v: int = None, - is_paged: bool = False, - total_num_pages: int = None, - page_size: int = None, - max_ctx_len: int = None, - qkv_format: str = "bshd", - cache_manager: KVCacheManager = None, - ): - self.max_batch_size = max_batch_size - self.max_seqlen_kv = max_seqlen_kv - self.num_heads_kv = num_heads_kv - self.head_dim_k = head_dim_k - self.dtype = dtype - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - self.is_paged = is_paged - - if not self.is_paged: - cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager - self.cache_manager = cls( - max_batch_size=self.max_batch_size, - max_seqlen=self.max_seqlen_kv, - num_heads=self.num_heads_kv, - head_dim_k=self.head_dim_k, - dtype=self.dtype, - head_dim_v=self.head_dim_v, - ) - else: - assert page_size is not None, "Paged KV cache requires page_size is not None." - self.page_size = page_size - assert ( - max_seqlen_kv % page_size == 0 - ), "Paged KV cache requires max_seqlen_kv % page_size = 0." - max_pages_per_seq = max_seqlen_kv // page_size - assert ( - total_num_pages == self.max_batch_size * max_pages_per_seq - ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." - self.total_num_pages = total_num_pages - - cls = cache_manager if cache_manager is not None else PagedKVCacheManager - self.cache_manager = cls( - total_num_pages=self.total_num_pages, - page_size=self.page_size, - num_heads=self.num_heads_kv, - head_dim_k=self.head_dim_k, - dtype=self.dtype, - max_batch_size=self.max_batch_size, - max_seqlen=self.max_seqlen_kv, - head_dim_v=self.head_dim_v, - ) - - if qkv_format == "thd": - assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" 
- self.max_ctx_len = max_ctx_len - - self.cache_qkv_format = "bshd" - self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format: - self.output_qkv_format = self.cache_qkv_format - else: - self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - - self.sequences_pre_step = OrderedDict() - self.sequences = OrderedDict() - self.batch_size = 0 - - self.cu_seqlens_q = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - self.cu_seqlens_kv = torch.zeros( - self.max_batch_size + 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - - def reset(self): - """Reset InferenceParams state""" - self.sequences = OrderedDict() - self.cache_manager.reset() - - def __repr__(self) -> str: - if self.is_paged: - return ( - f"dtype={self.dtype}, " - f"is_paged={self.is_paged}, " - f"total_pages={self.total_num_pages}, " - f"page_size={self.page_size}, " - f"num_heads={self.num_heads_kv}, " - f"head_dim_k={self.head_dim_k}, " - f"head_dim_v={self.head_dim_v}" - ) - return ( - f"dtype={self.dtype}, " - f"is_paged={self.is_paged}, " - f"max_batch_size={self.max_batch_size}, " - f"max_seqlen={self.max_seqlen_kv}, " - f"num_heads={self.num_heads_kv}, " - f"head_dim_k={self.head_dim_k}, " - f"head_dim_v={self.head_dim_v}" - ) - - def allocate_memory(self, layer_number: int): - """ - Allocate memory for the cache. 
For layer layer_number, - - NonPagedKVCacheManager: - - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] - - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] - - PagedKVCacheManager: - - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] - - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] - """ - self.cache_manager.allocate_memory(layer_number) - - def pre_step( - self, - step_dict: OrderedDict, - ): - """Update tracked sequences and prepare for step()""" - self.batch_size = len(step_dict) - - self.sequences = self.cache_manager.pre_step(step_dict) - # track the pre-step seqlens for the next layer in the model - self.sequences_pre_step = OrderedDict() - for k, v in self.sequences.items(): - self.sequences_pre_step[k] = v - step_dict[k] - - seqlens_q = list(step_dict.values()) - cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] - cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) - self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu")) - - seqlens_kv = list(self.sequences.values()) - cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)] - cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( - self.max_batch_size - self.batch_size - ) - self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu")) - - def get_seqlens_pre_step(self): - """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to( - dtype=torch.int32, device="cpu" - ) - - def convert_paged_to_nonpaged(self, layer_number: int): - """ - Convert k_cache and v_cache from paged to non-paged format. 
- - Parameters - ---------- - layer_number: int - Layer number of attention in the model - - Returns - ------- - k_cache: torch.Tensor - Non-paged key cache tensor - v_cache: torch.Tensor - Non-paged value cache tensor - """ - k_cache, v_cache = self.cache_manager.cache[layer_number] - page_table = self.cache_manager.page_table - batch_size = page_table.shape[0] - new_k_cache = rearrange( - k_cache[page_table.flatten()], - "(b npages) page_size ... -> b (npages page_size) ...", - b=batch_size, - ) - new_v_cache = rearrange( - v_cache[page_table.flatten()], - "(b npages) page_size ... -> b (npages page_size) ...", - b=batch_size, - ) - - new_k_cache = new_k_cache[: self.batch_size].contiguous() - new_v_cache = new_v_cache[: self.batch_size].contiguous() - - return new_k_cache, new_v_cache - - def step( - self, - layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - qkv_format: str, - ): - """ - Copy new KV tokens to the cache. - - Parameters - ---------- - layer_number: int - Layer number of attention in the model - new_k: torch.Tensor - New key tokens for layer_number in current inference iteration - new_v: torch.Tensor - New value tokens for layer_number in current inference iteration - qkv_format: str - Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - Full key tensor containing both previous and current key tokens - v_cache: torch.Tensor - Full value tensor containing both previous and current value tokens - cu_seqlens_q: torch.Tensor - Updated cumulative sequence lengths for query, [batch_size + 1] - cu_seqlens_kv: torch.Tensor - Updated cumulative sequence lengths for key and value, [batch_size + 1] - max_seqlen_q: int - Update maximum sequence length for query - max_seqlen_kv: int - Update maximum sequence length for key and value - qkv_format: str - Updated qkv_format, e.g. 
'thd' format becomes 'thd_2bshd' after step() - """ - self.input_qkv_format = qkv_format - if self.input_qkv_format == self.cache_qkv_format: - self.output_qkv_format = self.cache_qkv_format - else: - self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format - - k_cache, v_cache = self.cache_manager.step( - layer_number, - new_k, - new_v, - self.cu_seqlens_q, - self.cu_seqlens_kv, - qkv_format, - ) - - return ( - k_cache, - v_cache, - self.cu_seqlens_q, - self.cu_seqlens_kv, - self.max_seqlen_kv, - self.output_qkv_format, - ) - - -class NonPagedKVCacheManager(KVCacheManager): - """Non-paged KV cache manager""" - - def __init__( - self, - max_batch_size: int, - max_seqlen: int, - num_heads: int, - head_dim_k: int, - dtype: torch.dtype, - head_dim_v: Optional[int] = None, - ): - super().__init__() - """Initialize cache manager""" - self.max_batch_size = max_batch_size - self.max_seqlen = max_seqlen - self.num_heads = num_heads - self.head_dim_k = head_dim_k - self.dtype = dtype - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - - # track sequences in the cache, {seq_id: seq_len} - self.sequences = OrderedDict() - # cache tensors, cache[layer_number] = (k_cache, v_cache) - self.cache = {} - # track sequence indices in the batch in order to re-index k_cache and v_cache - self.batch_indices = torch.zeros( - self.max_batch_size, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - # after re-indexing, batch indices are always [0, ..., b-1] - self.batch_indices_post_step = torch.range( - 0, - self.max_batch_size - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - - def allocate_memory(self, layer_number): - """Allocate memory for the cache""" - k_cache = torch.zeros( - self.max_batch_size, - self.max_seqlen, - self.num_heads, - self.head_dim_k, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - v_cache = torch.zeros( - self.max_batch_size, - self.max_seqlen, - self.num_heads, - 
self.head_dim_v, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - self.cache[layer_number] = (k_cache, v_cache) - - def pre_step( - self, - step_dict: OrderedDict, - ): - """Update tracked sequences and prepare for step()""" - # Track unfinished sequences' indices in the batch, e.g. - # at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished - # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that - # they are contiguous and match the indexing in q - prev_batch_size = len(self.sequences) - unfinished_seqs = self.sequences.keys() & step_dict.keys() - finished_seqs = self.sequences.keys() - unfinished_seqs - unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] - finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( - torch.Tensor( - ( - unfinished_indices - + finished_indices - + list(range(prev_batch_size, self.max_batch_size)) - ) - ).to(dtype=torch.int32, device="cpu") - ) - - # Advance unfinished sequences - for i in unfinished_seqs: - self.sequences[i] += 1 - - # Remove finished sequences - for i in finished_seqs: - self.sequences.pop(i) - - # Add new sequences - new_seqs = step_dict.keys() - self.sequences.keys() - for i in new_seqs: - self.sequences[i] = step_dict[i] - - return self.sequences - - def step( - self, - layer_number, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_new_seqlens, - cu_cached_seqlens, - qkv_format: str, - ): - """ - Copy the new tokens to the non-paged KV cache. 
- - Parameters - ---------- - layer_number: int - Layer number of attention in the model - new_k: torch.Tensor - New key tokens for layer_number in current inference iteration - new_v: torch.Tensor - New value tokens for layer_number in current inference iteration - cu_new_seqlens: torch.Tensor - Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] - cu_cached_seqlens: torch.Tensor - Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] - qkv_format: str - Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - Full key tensor containing both previous and current key tokens - v_cache: torch.Tensor - Full value tensor containing both previous and current value tokens - """ - k_cache, v_cache = self.cache[layer_number] - - batch_size = self.max_batch_size - ctx_len = 1 - if qkv_format == "bshd": - batch_size = new_k.shape[0] - ctx_len = new_k.shape[1] - if qkv_format == "sbhd": - batch_size = new_k.shape[1] - ctx_len = new_k.shape[0] - - tex.copy_to_kv_cache( - new_k, - new_v, - k_cache, - v_cache, - self.batch_indices, - cu_new_seqlens, - cu_cached_seqlens, - QKVFormat[qkv_format], - batch_size, - ctx_len, - self.max_seqlen, - 1, - True, - ) - - k_cache = k_cache[:batch_size] - v_cache = v_cache[:batch_size] - - return k_cache, v_cache - - -class Page: - """A single page""" - - def __init__(self, page_id: int): - """Initialize a page""" - self.page_id = page_id - self.allocated = 0 - - def allocate_page(self): - """Allocate a page""" - self.allocated = True - - def deallocate_page(self): - """Deallocate a page""" - self.allocated = False - - -class PagedKVCacheManager(KVCacheManager): - """Paged KV cache manager""" - - def __init__( - self, - total_num_pages: int, - page_size: int, - num_heads: int, - head_dim_k: int, - dtype: torch.dtype, - max_batch_size: int, - max_seqlen: int, - head_dim_v: Optional[int] = None, - ): - 
super().__init__() - """Initialize cache manager""" - self.total_num_pages = total_num_pages - self.page_size = page_size - self.num_heads = num_heads - self.head_dim_k = head_dim_k - self.dtype = dtype - self.max_batch_size = max_batch_size - self.max_seqlen = max_seqlen - self.max_pages_per_seq = max_seqlen // self.page_size - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k - - # track sequences in the cache, {seq_id: seq_len} - self.sequences = OrderedDict() - # cache tensors, cache[layer_number] = (k_cache, v_cache) - self.cache = {} - # available pages, [Page(),...] - self.free_pages = [] - for i in range(self.total_num_pages): - self.free_pages.append(Page(i)) - # allocated pages, {seq_id: [page_id,...]} - self.allocated_pages = defaultdict(list) - # page table, [batch_size, max_pages_per_seq] - self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" - ) - - def reset(self): - """Reset cache manager state""" - self.sequences = OrderedDict() - self.free_pages = [] - for i in range(self.total_num_pages): - self.free_pages.append(Page(i)) - self.allocated_pages = defaultdict(list) - self.page_table.fill_(0) - - def allocate_memory(self, layer_number): - """Allocate memory for the cache""" - k_cache = torch.zeros( - self.total_num_pages, - self.page_size, - self.num_heads, - self.head_dim_k, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - v_cache = torch.zeros( - self.total_num_pages, - self.page_size, - self.num_heads, - self.head_dim_v, - dtype=self.dtype, - device=torch.cuda.current_device(), - ) - self.cache[layer_number] = (k_cache, v_cache) - - def print_cache(self): - """Print KV cache status""" - used_pages = [self.get_page_count(seq) for seq in self.sequences] - logger = logging.getLogger("PagedKVCacheManager") - logger.debug("Cache status:") - logger.debug( - " total pages: %s (used %s, free %s)", - self.total_num_pages, - sum(used_pages), - 
len(self.free_pages), - ) - logger.debug(" total sequences: %s", self.get_sequence_count()) - for i, seq in enumerate(self.sequences): - logger.debug( - " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", - i, - seq, - self.get_sequence_lengths()[i], - self.get_page_count(seq), - self.get_page_list(seq), - ) - - def get_sequence_count(self): - """Get the total number of sequences in the KV cache""" - return len(self.sequences) - - def get_sequence_lengths(self): - """Get the list of sequence lengths in the KV cache""" - return list(self.sequences.values()) - - def has_free_page(self) -> bool: - """Whether the page pool has any free pages left""" - return len(self.free_pages) > 0 - - def get_page_count(self, seq: int): - """Get the number of pages allocated to a sequence""" - return len(self.allocated_pages[seq]) - - def get_page_list(self, seq: int): - """Get the list of pages allocated to a sequence""" - return [x.page_id for x in self.allocated_pages[seq]] - - def get_page_table(self, sequences: List[int]): - """Get the page table, in shape [batch_size, max_pages_per_seq]""" - page_table = torch.Tensor( - [ - self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) - for seq in sequences - ] - ).to(dtype=torch.int32, device="cpu") - self.page_table[: self.get_sequence_count()].copy_(page_table) - return self.page_table - - def allocate_page(self, seq: int): - """Allocate a new page to a sequence""" - if not self.has_free_page(): - raise RuntimeError("KV cache is full!") - page = self.free_pages.pop(0) - page.allocate_page() - self.allocated_pages[seq].append(page) - - def allocate_sequence(self, seq: int, context_len: int): - """Add a new sequence to the cache""" - num_pages = context_len // self.page_size - if context_len % self.page_size > 0: - num_pages = num_pages + 1 - for _ in range(num_pages): - self.allocate_page(seq) - - def deallocate_sequence(self, seq: int): - """Deallocate all the pages for a 
sequence""" - for page in self.allocated_pages[seq]: - page.deallocate_page() - if not page.allocated: - self.free_pages.append(page) - self.allocated_pages.pop(seq) - - def pre_step( - self, - step_dict: OrderedDict, - ): - """Update tracked sequences and prepare for step()""" - # Remove finished sequences and advance unfinished sequences - unfinished_seqs = self.sequences.keys() & step_dict.keys() - finished_seqs = self.sequences.keys() - unfinished_seqs - for seq in finished_seqs: - self.sequences.pop(seq) - self.deallocate_sequence(seq) - for seq in unfinished_seqs: - if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: - self.allocate_page(seq) - self.sequences[seq] += 1 - - # Add new sequences - new_seqs = step_dict.keys() - self.sequences.keys() - for seq in new_seqs: - self.sequences[seq] = step_dict[seq] - self.allocate_sequence(seq, step_dict[seq]) - - # Get page table - self.page_table = self.get_page_table(list(self.sequences.keys())) - - return self.sequences - - def step( - self, - layer_number: int, - new_k: torch.Tensor, - new_v: torch.Tensor, - cu_new_seqlens, - cu_cached_seqlens, - qkv_format: str, - ): - """ - Copy the new tokens to the paged KV cache. 
- - Parameters - ---------- - layer_number: int - Layer number of attention in the model - new_k: torch.Tensor - New key tokens for layer_number in current inference iteration - new_v: torch.Tensor - New value tokens for layer_number in current inference iteration - cu_new_seqlens: torch.Tensor - Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] - cu_cached_seqlens: torch.Tensor - Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] - qkv_format: str - Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} - - Returns - ------- - k_cache: torch.Tensor - Full key tensor containing both previous and current key tokens - v_cache: torch.Tensor - Full value tensor containing both previous and current value tokens - """ - k_cache, v_cache = self.cache[layer_number] - - batch_size = self.max_batch_size - ctx_len = 1 - if qkv_format == "bshd": - batch_size = new_k.shape[0] - ctx_len = new_k.shape[1] - if qkv_format == "sbhd": - batch_size = new_k.shape[1] - ctx_len = new_k.shape[0] - - tex.copy_to_kv_cache( - new_k, - new_v, - k_cache, - v_cache, - self.page_table, - cu_new_seqlens, - cu_cached_seqlens, - QKVFormat[qkv_format], - batch_size, - ctx_len, - self.max_seqlen, - self.max_pages_per_seq, - False, - ) - - return k_cache, v_cache From b598cb989e1db2f6e20a77420fd23820a6d72507 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 18:53:54 -0700 Subject: [PATCH 225/239] fix merge 2 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 23 +++++++++---------- .../pytorch/dot_product_attention/utils.py | 10 ++++---- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 165a86b4f7..76587d485e 100644 --- a/transformer_engine/pytorch/attention.py +++ 
b/transformer_engine/pytorch/attention.py @@ -77,7 +77,6 @@ # Import attention utils import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils -from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb @@ -4129,7 +4128,7 @@ def forward( context_parallel = cp_size > 1 # get q_format and kv_format for training and inference - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params) # convert q, k, v to bshd if they are in sbhd; qkv_format doesn't change if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): @@ -5125,7 +5124,7 @@ def forward( context_parallel = cp_size > 1 # get q_format and kv_format for training and inference - qkv_format, q_format, kv_format = get_qkv_format(qkv_layout, inference_params) + qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params) page_table = None if inference_params is None: @@ -5150,20 +5149,20 @@ def forward( "Please provide attention_mask or cu_seqlens for padding!" 
) if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) else: if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_q, query_layer.device, ) if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_kv, key_layer.device, @@ -6046,9 +6045,9 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) else: cu_seqlens_q = dpa_utils.get_full_cu_seqlens( batch_size, @@ -6063,9 +6062,9 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" 
if self.attention_type == "self": - cu_seqlens_kv = get_cu_seqlens(attention_mask) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask) else: - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) else: cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( batch_size, diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py index a4f05334e1..8ce8595021 100644 --- a/transformer_engine/pytorch/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -115,11 +115,11 @@ class FlashAttentionUtils: # FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper # Please follow these instructions to install FA3 v3_installation_steps = """\ - (1) git clone https://github.com/Dao-AILab/flash-attention.git - (2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install - (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` - (4) mkdir -p $python_path/flash_attn_3 - (5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" +(1) git clone https://github.com/Dao-AILab/flash-attention.git +(2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install +(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` +(4) mkdir -p $python_path/flash_attn_3 +(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" @staticmethod def set_flash_attention_version(): From 5e4544203a9e0a63c8f38a2ab9099851349c7a82 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 19:22:57 -0700 Subject: [PATCH 226/239] fix FA import comments 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 4 +--- transformer_engine/pytorch/dot_product_attention/utils.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 76587d485e..f98acad02f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -85,7 +85,7 @@ # Setup Attention Logging attn_log.setup_logging() -# Global vars for flash attn imports +# Global vars for flash attn v2 and v3 imports flash_attn_cuda_bwd = None flash_attn_func = None flash_attn_varlen_func = None @@ -135,8 +135,6 @@ ), fa_utils.version, ) - -# Detect flash-attn v3 in the environment (Hopper only) try: _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py index 8ce8595021..5d939a370f 100644 --- a/transformer_engine/pytorch/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -92,7 +92,6 @@ class FlashAttentionUtils: Manage Flash Attention versioning information """ - # Detect flash-attn v2 in the environment is_installed = False version = PkgVersion("0") version_required = PkgVersion("2.1.1") @@ -137,7 +136,6 @@ def set_flash_attention_version(): FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0") FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0") - # Detect flash-attn v3 in the environment (Hopper only) @staticmethod def set_flash_attention_3_params(): """ From d77011628b4c10960eaeb643c35d8c3610a2491c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 19:41:31 -0700 Subject: [PATCH 227/239] relax tols for Ampere Signed-off-by: 
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index e0b970620e..340366b69e 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -372,7 +372,7 @@ def generate_args( def get_tols(module, backend, dtype): if module == "TransformerLayer": tols = { - torch.half: (4e-3, 4e-3), + torch.half: (5e-3, 5e-3), torch.bfloat16: (3.5e-2, 3.5e-2), } if module == "DotProductAttention": From 0025478433d049ed4acf9bb21ac089736f8b5eac Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 14 Mar 2025 20:28:18 -0700 Subject: [PATCH 228/239] fix fa3 version and reduce messaging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- .../pytorch/dot_product_attention/utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f98acad02f..3a877a8d81 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -136,7 +136,7 @@ fa_utils.version, ) try: - _flash_attn_3_version = PkgVersion(get_pkg_version("flash-attn-3")) + fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: pass # only print warning if use_flash_attention_3 = True in get_attention_backend else: diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py index 5d939a370f..329119e45e 100644 --- a/transformer_engine/pytorch/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -106,6 +106,7 @@ class FlashAttentionUtils: v2_5_7_plus = False 
v2_6_0_plus = False v2_7_0_plus = False + warning_printed = False v3_is_installed = False fa3_version = PkgVersion("0") @@ -119,6 +120,7 @@ class FlashAttentionUtils: (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 (5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" + v3_warning_printed = False @staticmethod def set_flash_attention_version(): @@ -842,14 +844,15 @@ def get_attention_backend( # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. # When `FusedAttention` does not support the provided attention params, and `FlashAttention` # does, we recommend users to install flash-attn if not installed already. - if not use_fused_attention: - if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed: + if not use_fused_attention and _NVTE_FLASH_ATTN: + if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed and not FlashAttentionUtils.v3_warning_printed and torch.cuda.current_device() == 0: logger.warning( "flash-attn v3 may provide important feature support or performance improvement." " Please install flash-attn v3 by \n%s", FlashAttentionUtils.v3_installation_steps, ) - elif use_flash_attention_2 and not FlashAttentionUtils.is_installed: + FlashAttentionUtils.v3_warning_printed = True + elif use_flash_attention_2 and not FlashAttentionUtils.is_installed and not FlashAttentionUtils.warning_printed and torch.cuda.current_device() == 0: logger.warning( "flash-attn may provide important feature support or performance improvement." 
" Please install flash-attn %s by pip3 install flash-attn==.", @@ -858,6 +861,7 @@ def get_attention_backend( FlashAttentionUtils.max_version, ), ) + FlashAttentionUtils.warning_printed = True # All available backends if use_flash_attention_2 and not FlashAttentionUtils.is_installed: use_flash_attention_2 = False From bec87e7ef4e7ae11682dce55bccc8ca3216f23f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Mar 2025 03:28:46 +0000 Subject: [PATCH 229/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/dot_product_attention/utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py index 329119e45e..d3dd8d8c23 100644 --- a/transformer_engine/pytorch/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -845,14 +845,24 @@ def get_attention_backend( # When `FusedAttention` does not support the provided attention params, and `FlashAttention` # does, we recommend users to install flash-attn if not installed already. if not use_fused_attention and _NVTE_FLASH_ATTN: - if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed and not FlashAttentionUtils.v3_warning_printed and torch.cuda.current_device() == 0: + if ( + use_flash_attention_3 + and not FlashAttentionUtils.v3_is_installed + and not FlashAttentionUtils.v3_warning_printed + and torch.cuda.current_device() == 0 + ): logger.warning( "flash-attn v3 may provide important feature support or performance improvement." 
" Please install flash-attn v3 by \n%s", FlashAttentionUtils.v3_installation_steps, ) FlashAttentionUtils.v3_warning_printed = True - elif use_flash_attention_2 and not FlashAttentionUtils.is_installed and not FlashAttentionUtils.warning_printed and torch.cuda.current_device() == 0: + elif ( + use_flash_attention_2 + and not FlashAttentionUtils.is_installed + and not FlashAttentionUtils.warning_printed + and torch.cuda.current_device() == 0 + ): logger.warning( "flash-attn may provide important feature support or performance improvement." " Please install flash-attn %s by pip3 install flash-attn==.", From 4a74ef8c66da29f202f7dadb73d435364ae57ec5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 18 Mar 2025 00:16:53 +0530 Subject: [PATCH 230/239] Add issue template (#1584) * Add issue template Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * Make GPU info section Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .github/ISSUE_TEMPLATE/bug_report.md | 47 +++++++++++++++++++++++ .github/ISSUE_TEMPLATE/feature_request.md | 25 ++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000..fef0dde9ba --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,47 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** + +A clear and concise description of what the bug is. + +**Steps/Code to reproduce bug** + +Please list *minimal* steps or code snippet for us to be able to reproduce the bug. + +A helpful guide on on how to craft a minimal bug report http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports. 
+ + +**Expected behavior** + +A clear and concise description of what you expected to happen. + +**Environment overview (please complete the following information)** + + - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider - AWS, Azure, GCP, Collab)] + - Method of Transformer Engine install: [pip install or from source]. Please specify exact commands you used to install. + - If method of install is [Docker], provide `docker pull` & `docker run` commands used + +**Environment details** + +If NVIDIA docker image is used you don't need to specify these. +Otherwise, please provide: +- OS version +- PyTorch version +- Python version +- Transformer Engine version +- CUDA version +- CUDNN version + +**Device details** +- GPU model + +**Additional context** + +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000..355e553939 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,25 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: feature request +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** + +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** + +A clear and concise description of what you want to happen. +Provide a code snippet on how new APIs/changes would be used by others. + +**Describe alternatives you've considered** + +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** + +Add any other context or screenshots about the feature request here. 
From 7ddc59323d6b6f1bcf4f5023b26d4383fac91889 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 17 Mar 2025 11:47:56 -0700 Subject: [PATCH 231/239] Better cuBLAS handle management (#1389) * Do not create multiple cublas handle Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix for multiple GPUs per thread Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix multithreaded execution Signed-off-by: Przemek Tredak * Fix from conlfict Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/common/cudnn_utils.cpp | 12 +++-- transformer_engine/common/cudnn_utils.h | 30 ++++------- .../common/fused_attn/fused_attn.cpp | 12 ++--- .../common/gemm/cublaslt_gemm.cu | 10 +++- .../common/normalization/common.cpp | 2 +- .../common/util/handle_manager.h | 52 +++++++++++++++++++ 6 files changed, 84 insertions(+), 34 deletions(-) create mode 100644 transformer_engine/common/util/handle_manager.h diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index 80d2707315..eaf6de680a 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -57,9 +57,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) } } -void nvte_cudnn_handle_init() { - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); -} +void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } + +namespace detail { + +void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } + +} // namespace detail } // namespace transformer_engine @@ -68,6 +72,6 @@ namespace cudnn_frontend { // This is needed to define the symbol 
`cudnn_dlhandle` // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // to enable dynamic loading. -void *cudnn_dlhandle = nullptr; +void* cudnn_dlhandle = nullptr; } // namespace cudnn_frontend diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h index eb19b9ddb2..0016ad7f55 100644 --- a/transformer_engine/common/cudnn_utils.h +++ b/transformer_engine/common/cudnn_utils.h @@ -10,37 +10,25 @@ #include #include #include - -#include -#include +#include #include "transformer_engine/transformer_engine.h" +#include "util/handle_manager.h" namespace transformer_engine { -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); +namespace detail { -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); +void CreateCuDNNHandle(cudnnHandle_t* handle); -class cudnnExecutionPlanManager { - public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } +} // namespace detail - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); - ~cudnnExecutionPlanManager() {} +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); - private: - cudnnHandle_t handle_ = nullptr; -}; +using cudnnExecutionPlanManager = detail::HandleManager; } // namespace transformer_engine -#endif +#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 13c99ae244..a131aab7f3 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -329,7 +329,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, t = input_QKV->data.shape[0]; } - auto handle = 
cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -411,7 +411,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -511,7 +511,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const t_kv = input_KV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -602,7 +602,7 @@ void nvte_fused_attn_bwd_kvpacked( t_kv = input_KV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -699,7 +699,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_kv = input_K->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); @@ -786,7 +786,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso t_kv = input_K->data.shape[0]; } - auto handle = 
cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 39b887783b..d24d114c29 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -14,6 +14,7 @@ #include #include "../common.h" +#include "../util/handle_manager.h" #include "../util/logging.h" #include "common/util/cuda_runtime.h" @@ -47,6 +48,10 @@ uint32_t _getAlignment(uintptr_t address) { } } +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + struct GemmParam { void *A; void *B; @@ -140,6 +145,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla namespace transformer_engine { +using cublasHandleManager = detail::HandleManager; + void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, @@ -192,8 +199,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, float zero = 0.0; float beta = (accumulate) ? 
one : zero; - cublasLtHandle_t handle; - NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 7ef3ac44e7..ddda78d951 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -211,7 +211,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); - _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + _handle = cudnnExecutionPlanManager::Instance().GetHandle(); _graph.set_io_data_type(get_cudnn_fe_dtype(itype)) .set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h new file mode 100644 index 0000000000..adb2f55587 --- /dev/null +++ b/transformer_engine/common/util/handle_manager.h @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ +#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ + +#include + +#include "cuda_runtime.h" +#include "logging.h" + +namespace transformer_engine::detail { + +template +class HandleManager { + public: + static HandleManager& Instance() { + static thread_local HandleManager instance; + return instance; + } + + Handle GetHandle() { + static thread_local std::vector initialized(handles_.size(), false); + const int device_id = cuda::current_device(); + NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); + if (!initialized[device_id]) { + Create(&(handles_[device_id])); + initialized[device_id] = true; + } + return handles_[device_id]; + } + + ~HandleManager() { + if (Destroy != nullptr) { + for (auto& handle : handles_) { + Destroy(handle); + } + } + } + + private: + HandleManager() : handles_(cuda::num_devices(), nullptr) {} + + std::vector handles_ = nullptr; +}; + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ From cb2d56e555e504d52e238547f0ad41b43610f5ed Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 17 Mar 2025 12:57:05 -0700 Subject: [PATCH 232/239] update FA3 to its latest commit on main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L3_pytorch_FA_versions_test/test.sh | 6 +++--- tests/pytorch/fused_attn/test_paged_attn.py | 2 +- transformer_engine/pytorch/dot_product_attention/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index febc1aa1ad..a805cffafa 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -17,7 +17,7 @@ if [ $sm_arch -gt 90 ] then FA_versions=(2.7.3) else - FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 
2.7.3 3.0.0b1) + FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1) fi for fa_version in "${FA_versions[@]}" @@ -29,10 +29,10 @@ do pip3 install flash-attn==${fa_version} else git clone https://github.com/Dao-AILab/flash-attention.git - cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install + cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install python_path=`python -c "import site; print(site.getsitepackages()[0])"` mkdir -p $python_path/flash_attn_3 - wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py + wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py cd ../../ fi diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py index 340366b69e..f810f11195 100644 --- a/tests/pytorch/fused_attn/test_paged_attn.py +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -379,7 +379,7 @@ def get_tols(module, backend, dtype): tols = { torch.half: (1e-3, 1e-3), torch.bfloat16: (1e-2, 1e-3), - torch.float8_e4m3fn: (1e-2, 3e-2), + torch.float8_e4m3fn: (2e-2, 3e-2), } return tols[dtype] diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py index d3dd8d8c23..bae237c592 100644 --- a/transformer_engine/pytorch/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -116,10 +116,10 @@ class FlashAttentionUtils: # Please follow these instructions to install FA3 v3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git -(2) cd flash-attention/ && git checkout 39e7197 && cd hopper/ && python setup.py install +(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install (3) 
python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 -(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/39e71975642daab365a5a37c959182c93ed5fc8a/hopper/flash_attn_interface.py""" +(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py""" v3_warning_printed = False @staticmethod From 6a855962e9f0582a1e6c0b0084d0fe6ad94872ab Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Mon, 17 Mar 2025 13:32:48 -0700 Subject: [PATCH 233/239] Distopt with offload (#1573) * DistOpt support with offloading Signed-off-by: Selvaraj Anandaraj * Added distopt support for TE2.0 Signed-off-by: Selvaraj Anandaraj * Restricted this to MCore DistOpt only Signed-off-by: Selvaraj Anandaraj * Added guards Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/module/linear.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/module/layernorm_linear.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/module/layernorm_linear.py | 18 ++++++++++++++++-- transformer_engine/pytorch/module/linear.py | 19 ++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9c3c798e68..4022924861 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -383,6 +383,17 @@ def forward( ) 
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_object = weight + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, weightmat, @@ -526,8 +537,11 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight.main_grad = main_grad + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + origin_weight = ctx.weight_object + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + origin_weight.main_grad = main_grad ctx.ub_obj_gradout = None ub_obj_dgrad = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 77b52dae26..f96355a678 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -291,6 +291,17 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. 
Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_object = weight + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -392,9 +403,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) - weight.main_grad = main_grad + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + weight = ctx.weight_object + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already From c571c2fd4068bd60d6f8a4aa088afac478dcfbe1 Mon Sep 17 00:00:00 2001 From: linxiddd Date: Tue, 18 Mar 2025 05:52:35 +0800 Subject: [PATCH 234/239] [QA] Add error handling (#1570) * [QA] Add error handling -Standardize test failure handling using the unified 'test_fail' function and 'error_exit' function. Signed-off-by: Linxi Ding * Update script to use explicit python3, pip3, and python3 -m pytest calls - Change pip to pip3. - Change python to python3. - Change pytest to python3 -m pytest. 
Signed-off-by: Linxi Ding --------- Signed-off-by: Linxi Ding --- qa/L0_jax_distributed_unittest/test.sh | 29 ++++++++-- qa/L0_jax_unittest/test.sh | 37 +++++++++---- qa/L0_jax_wheel/test.sh | 47 ++++++++++++----- qa/L0_pytorch_unittest/test.sh | 61 ++++++++++++++-------- qa/L0_pytorch_wheel/test.sh | 47 ++++++++++++----- qa/L1_jax_distributed_unittest/test.sh | 2 +- qa/L1_pytorch_distributed_unittest/test.sh | 37 +++++++++---- qa/L3_pytorch_FA_versions_test/test.sh | 2 +- 8 files changed, 186 insertions(+), 76 deletions(-) diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 8bd8a236d8..3253861484 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -2,14 +2,33 @@ # # See LICENSE for license information. -set -xe +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} -pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +. 
$TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 83343e42b5..1f7bb0ebc4 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -2,22 +2,41 @@ # # See LICENSE for license information. -set -xe +function error_exit() { + echo "Error: $1" + exit 1 +} -pip3 install "nltk>=3.8.2" -pip3 install pytest==8.2.1 +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + +pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${TE_PATH:=/opt/transformerengine} -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "test_praxis_layers.py" # Test without custom calls -NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py +NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" -pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt -pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt +pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist 
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 48254e24e1..e1400b10bd 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -2,34 +2,53 @@ # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : "${TE_PATH:=/opt/transformerengine}" -pip3 install wheel +pip3 install wheel || error_exit "Failed to install wheel" cd $TE_PATH -pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax +pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-jax" VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. 
-NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel -wheel unpack dist/* +NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/* || error_exit "Failed to unpack dist/*" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" +rm dist/*.whl || error_exit "Failed to remove dist/*.whl" +mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" cd transformer_engine/jax -NVTE_RELEASE_BUILD=1 python3 setup.py sdist +NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" -pip3 install dist/* +pip3 install dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip3 install dist/*.whl --no-deps +pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" + +python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py" -python3 $TE_PATH/tests/jax/test_sanity_import.py +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 
ff7527841a..aa829cd5d2 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -2,29 +2,46 @@ # # See LICENSE for license information. +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + set -x : ${TE_PATH:=/opt/transformerengine} -pip3 install pytest==8.2.1 - -FAIL=0 - -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1 -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 -NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1 -NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 - -exit $FAIL +pip3 install pytest==8.2.1 || error_exit "Failed to install 
pytest" + +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests 
passed" +exit 0 diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index 5f583af31e..ffd5ca2909 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -2,34 +2,53 @@ # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : "${TE_PATH:=/opt/transformerengine}" -pip3 install wheel +pip3 install wheel || error_exit "Failed to install wheel" cd $TE_PATH -pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch +pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-torch" VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. -NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel -wheel unpack dist/* +NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/* || error_exit "Failed to unpack dist/*" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" +rm dist/*.whl || 
error_exit "Failed to remove dist/*.whl" +mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" cd transformer_engine/pytorch -NVTE_RELEASE_BUILD=1 python3 setup.py sdist +NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" -pip3 install dist/* +pip3 install dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip3 install dist/*.whl --no-deps +pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" + +python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" -python3 $TE_PATH/tests/pytorch/test_sanity_import.py +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index e47aa15fbd..96c5949a99 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -6,4 +6,4 @@ set -xe : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 597551abfe..5776734c3b 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -2,17 +2,34 @@ # # See LICENSE for license information. 
-: ${TE_PATH:=/opt/transformerengine} +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} -pip3 install pytest==8.2.1 +RET=0 +FAILED_CASES="" + +: ${TE_PATH:=/opt/transformerengine} -FAIL=0 +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1 -python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1 -# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || FAIL=1 ### TODO Debug UB support with te.Sequential -python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1 +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential +python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" -exit $FAIL +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L3_pytorch_FA_versions_test/test.sh 
b/qa/L3_pytorch_FA_versions_test/test.sh index f57d055db5..ea4f47a9a4 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -35,6 +35,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py done From d35d00c11bc7abd61609bc4608c59375dfe17df3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Mar 2025 22:59:14 +0000 Subject: [PATCH 235/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- qa/L0_pytorch_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index f6a52fca44..732f0a16d1 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -44,4 +44,4 @@ if [ "$RET" -ne 0 ]; then exit 1 fi echo "All tests passed" -exit 0 \ No newline at end of file +exit 0 From 5da6e91b3061bad7bdd50e714cc930241bad960d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 17 Mar 2025 14:42:54 -0700 Subject: [PATCH 236/239] add default values to IP and assertion to graph.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/dot_product_attention/inference.py | 6 +++--- transformer_engine/pytorch/graph.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/dot_product_attention/inference.py b/transformer_engine/pytorch/dot_product_attention/inference.py index 46e961b381..c28b2013af 100644 --- a/transformer_engine/pytorch/dot_product_attention/inference.py +++ b/transformer_engine/pytorch/dot_product_attention/inference.py @@ -128,9 +128,9 @@ def __init__( self, max_batch_size: int, max_seqlen_kv: int, - num_heads_kv: int, 
- head_dim_k: int, - dtype: torch.dtype, + num_heads_kv: int = 16, + head_dim_k: int = 64, + dtype: torch.dtype = torch.bfloat16, head_dim_v: int = None, is_paged: bool = False, total_num_pages: int = None, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 827d43196d..0a97e517a0 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -93,6 +93,8 @@ def _make_graphed_callables( # Check training/inference is_training = all(c.training for c in callables) + if not is_training and any(c.training for c in callables): + assert False, "make_graphed_callables only supports when modules are all in training or all in inference mode." # Check sizes of args if _order is None: From 666f771dbf3f95710d63bb4d13f6d1e204a82f2a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:05:37 -0700 Subject: [PATCH 237/239] add more comments in attention Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3a877a8d81..afb6b92f04 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5965,6 +5965,8 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" + # convert top-left causal to bottom-right causal due to KV caching + # users can still use the same attention mask for inference as for training assert "padding" in attn_mask_type, "KV caching requires padding mask!" 
if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" @@ -5979,6 +5981,7 @@ def forward( for x in [query_layer, key_layer, value_layer] ] + # get full K/V tensors from cache and adjust cu_seqlens, qkv_format based on the cache ( key_layer, value_layer, @@ -5995,7 +5998,7 @@ def forward( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None - # get qkv_layout + # get qkv's memory layout if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): ( qkv_layout, @@ -6027,7 +6030,7 @@ def forward( inference_params=inference_params, ) - # adjust max_seqlen and cu_seqlens + # adjust max_seqlen and cu_seqlens for CP cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -6898,7 +6901,7 @@ def forward( ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" # ================================================= - # Pre-allocate memory for key-values for inference + # Pre-allocate memory for key-value cache for inference # ================================================= if ( From 22f79f81b36d10178aec604e04db5856e036a62f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:06:10 -0700 Subject: [PATCH 238/239] use custom_cache_manager instead of cache_manager Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../pytorch/dot_product_attention/inference.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/dot_product_attention/inference.py b/transformer_engine/pytorch/dot_product_attention/inference.py index c28b2013af..ac6e6de896 100644 --- a/transformer_engine/pytorch/dot_product_attention/inference.py +++ b/transformer_engine/pytorch/dot_product_attention/inference.py @@ -120,7 +120,7 @@ class DotProductAttention: Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. 
qkv_format: str, default = "bshd" Format of the incoming query/key/value tensors in current iteration - cache_manager: KVCacheManager, default = None + custom_cache_manager: KVCacheManager, default = None Custom cache manager, with KVCacheManager as the base class. """ @@ -137,7 +137,7 @@ def __init__( page_size: int = None, max_ctx_len: int = None, qkv_format: str = "bshd", - cache_manager: KVCacheManager = None, + custom_cache_manager: KVCacheManager = None, ): self.max_batch_size = max_batch_size self.max_seqlen_kv = max_seqlen_kv @@ -148,8 +148,8 @@ def __init__( self.is_paged = is_paged if not self.is_paged: - cls = cache_manager if cache_manager is not None else NonPagedKVCacheManager - self.cache_manager = cls( + cache_manager = custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager + self.cache_manager = cache_manager( max_batch_size=self.max_batch_size, max_seqlen=self.max_seqlen_kv, num_heads=self.num_heads_kv, @@ -169,8 +169,8 @@ def __init__( ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." 
self.total_num_pages = total_num_pages - cls = cache_manager if cache_manager is not None else PagedKVCacheManager - self.cache_manager = cls( + cache_manager = custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager + self.cache_manager = cache_manager( total_num_pages=self.total_num_pages, page_size=self.page_size, num_heads=self.num_heads_kv, From cfd30cfdbd6002b4365db72a3fdbfcce60ceb7f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Mar 2025 23:06:49 +0000 Subject: [PATCH 239/239] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/dot_product_attention/inference.py | 8 ++++++-- transformer_engine/pytorch/graph.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/dot_product_attention/inference.py b/transformer_engine/pytorch/dot_product_attention/inference.py index ac6e6de896..ae220225e8 100644 --- a/transformer_engine/pytorch/dot_product_attention/inference.py +++ b/transformer_engine/pytorch/dot_product_attention/inference.py @@ -148,7 +148,9 @@ def __init__( self.is_paged = is_paged if not self.is_paged: - cache_manager = custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager + cache_manager = ( + custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager + ) self.cache_manager = cache_manager( max_batch_size=self.max_batch_size, max_seqlen=self.max_seqlen_kv, @@ -169,7 +171,9 @@ def __init__( ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." 
self.total_num_pages = total_num_pages - cache_manager = custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager + cache_manager = ( + custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager + ) self.cache_manager = cache_manager( total_num_pages=self.total_num_pages, page_size=self.page_size, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 0a97e517a0..0479aebb4d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -94,7 +94,10 @@ def _make_graphed_callables( # Check training/inference is_training = all(c.training for c in callables) if not is_training and any(c.training for c in callables): - assert False, "make_graphed_callables only supports when modules are all in training or all in inference mode." + assert False, ( + "make_graphed_callables only supports when modules are all in training or all in" + " inference mode." + ) # Check sizes of args if _order is None: